diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..48b75c16 --- /dev/null +++ b/.clang-format @@ -0,0 +1,60 @@ +# clang-format configuration for LLTFI +# +# Based on LLVM style (K&R braces, 2-space indent, 80-col limit). +# Run: clang-format-20 -i +# Or: lint.sh --fix (reformats all C++ files in-place) + +--- +BasedOnStyle: LLVM +Language: Cpp +Standard: c++17 + +# Indentation +IndentWidth: 2 +TabWidth: 2 +UseTab: Never +ContinuationIndentWidth: 4 +IndentCaseLabels: false +IndentPPDirectives: None + +# Line length +ColumnLimit: 80 + +# Line break +LineEnding: LF + +# Braces — attach (K&R / LLVM style) +BreakBeforeBraces: Attach + +# Includes +SortIncludes: CaseSensitive +IncludeBlocks: Regroup +IncludeCategories: + # LLTFI local headers first + - Regex: '^"(FI|Utils|Controller|Profil|Inst|Reg|Gen|LLFIDot)' + Priority: 1 + # LLVM headers + - Regex: '^"llvm/' + Priority: 2 + # System headers + - Regex: '^<' + Priority: 3 + +# Pointer alignment: right (int *p, not int* p) +PointerAlignment: Right + +# Function arguments +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +BinPackArguments: true +BinPackParameters: true + +# Short constructs — keep consistent with LLVM defaults +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false + +# Misc +SpaceBeforeParens: ControlStatements +SpacesInAngles: Never +SpacesInParentheses: false diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 00000000..cb1e74f7 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,62 @@ +# clang-tidy configuration for LLTFI +# +# Run from the repo root (requires compile_commands.json in the build dir): +# clang-tidy-20 -p /path/to/LLTFI-build +# Or use lint.sh which handles discovery automatically. +# +# To generate compile_commands.json, add -DCMAKE_EXPORT_COMPILE_COMMANDS=ON +# to the cmake invocation inside ./setup, or run: +# cd /path/to/LLTFI-build && cmake -DCMAKE_EXPORT_COMPILE_COMMANDS=ON . + +--- +Checks: > + bugprone-*, + clang-analyzer-core.*, + clang-analyzer-cplusplus.*, + clang-analyzer-deadcode.*, + clang-analyzer-security.*, + cppcoreguidelines-init-variables, + cppcoreguidelines-narrowing-conversions, + cppcoreguidelines-slicing, + llvm-namespace-comment, + llvm-twine-local, + modernize-use-nullptr, + modernize-use-override, + modernize-redundant-void-arg, + performance-for-range-copy, + performance-implicit-conversion-in-loop, + performance-unnecessary-copy-initialization, + readability-const-return-type, + readability-container-size-empty, + readability-delete-null-pointer, + readability-misplaced-array-index, + readability-redundant-declaration, + -bugprone-easily-swappable-parameters, + -bugprone-macro-parentheses, + -bugprone-branch-clone, + -bugprone-assignment-in-if-condition, + -clang-analyzer-optin.*, + -clang-analyzer-cplusplus.NewDelete, + -clang-diagnostic-macro-redefined, + +# Checks disabled because they fire heavily on legitimate LLVM-pass idioms: +# cppcoreguidelines-pro-bounds-* — pointer arithmetic is normal in LLVM IR +# cppcoreguidelines-pro-type-* — reinterpret_cast used in IR manipulation +# cppcoreguidelines-avoid-magic-numbers — opcode numbers are intentional +# cppcoreguidelines-pro-type-vararg — some C APIs use varargs +# modernize-use-trailing-return-type — not in our style guide + +WarningsAsErrors: '' + +HeaderFilterRegex: 'llvm_passes/.*\.h$' + +CheckOptions: + # Enforce nullptr over NULL + - key: modernize-use-nullptr.NullMacros + value: 'NULL' + + # Namespace comment style: closing brace should say "// namespace llfi" + - key: llvm-namespace-comment.ShortNamespaceLines + value: '10' + - key: llvm-namespace-comment.SpacesBeforeComments + value: '2' diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..57e92328 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,26 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.{cpp,h,c}] +indent_style = space +indent_size = 2 + +[*.py] +indent_style = space +indent_size = 4 + +[*.{yaml,yml}] +indent_style = space +indent_size = 2 + +[CMakeLists.txt] +indent_style = space +indent_size = 2 + +[Makefile] +indent_style = tab diff --git a/.gitignore b/.gitignore index 414915c1..afc35fd6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,17 +6,7 @@ config/java_paths.cmake config/java_paths.py tools/zgrviewer/llfi_run.sh -# Same holds for the installer directory -installer/downloads/* -installer/llfi/* -installer/llfisrc/* -installer/llvm/* -installer/llvmsrc/* -installer/pyyaml/* -installer/pyyamlsrc/* -installer/fontconfig/* build/ -sample_programs/ # All the .ll and .bc files *.ll @@ -28,8 +18,9 @@ sample_programs/ __pycache__/ *.pyc -# FIDL-generated software fault selector files (regenerated by setup via FIDL-Algorithm.py) -llvm_passes/software_failures/_*_*Selector.cpp +# protobuf build artifacts (downloaded during ML dependency setup) +protobuf-*/ +protobuf-all-*.zip # gedit backup files *~ @@ -56,7 +47,3 @@ gui/sum/ gui/min/ gui/bfs/ -# web app -web-app/node_modules/ -web-app/views/bundle.min.js -web-app/server/uploads/* diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..1bbf0014 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,160 @@ +# Changelog + +All notable changes to LLTFI are recorded here. +Format follows [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). + +--- + +## [Unreleased] — LLVM 20 branch (`LLVM20`) + +This release upgrades LLTFI from LLVM 15 to LLVM 20. Every change is +backward-compatible with LLFI. The full migration narrative, task breakdown, +and effort accounting are in `migration.md`. + +### Breaking changes + +- **LLVM version requirement raised from 15 to 20.** LLVM 15 is no longer + supported. Install LLVM 20 via the LLVM apt repository or build from source + (see `README.md`). +- **Legacy pass manager (`opt -load`, `-enable-new-pm=0`) removed.** All + passes — including `InstructionDuplication` — now use the new pass manager + exclusively. Any external scripts calling `opt` directly must be updated to + use `-load-pass-plugin` and `--passes=`. +- **`InstructionDuplication` pass renamed** to `InstructionDuplicationPass` in + the plugin registry to match the new PM convention. + +--- + +### Added + +#### New passes and tests +- `InstructionDuplication` migrated to the new pass manager (`PassInfoMixin`); + exposed as `"InstructionDuplicationPass"` in `SEDPasses.so`. +- Two new tests in `test_instruction_duplication.py`: + - `real_model_structural` — applies `InstructionDuplicationPass` to a real + onnx-mlir `model.ll` and verifies `compareFloatValues` calls are inserted. + - `real_model_end_to_end` — runs the baseline and duplicated models through + `lli` and asserts outputs are identical (SKIP when `model.ll` absent). + +#### Tooling +- `lint.sh` — unified C++ and Python lint runner; `--fix` auto-formats in-place. +- `.clang-tidy` — project-level tidy config with intentionally disabled checks + documented. +- `.clang-format` — project-level format config (LLVM style, 2-space indent). +- `setup.cfg` — `flake8` and `flake8-bugbear` configuration for Python linting. + +#### Documentation +- `architecture.md` — new developer reference covering pass pipeline, selector + class hierarchy, hardware/software/ML fault modes, the runtime library, and + the interface between compile-time and runtime layers. +- `docs/input_yaml_guide.md` — prose guide to writing `input.yaml` files, + covering all keys, `CustomTensorOperator` ML targeting, and complete examples. +- `docs/tutorial_first_experiment.md` — end-to-end walkthrough of the + `factorial` experiment including output file interpretation and outcome + classification (masked / SDC / crash / hang). +- `docs/adding_a_test.md` — step-by-step guide for adding a regression test, + covering program registration, test case structure, custom Python scripts, + and the SKIP convention. +- `CODING_GUIDELINES.md` — expanded with sections on `override`, variable + initialisation, container emptiness (`.empty()` over `.size() == 0`), and + `cast<>` vs. `dyn_cast<>`. +- `CONTRIBUTING.md` — added `Adding a Test Case` section pointing to + `docs/adding_a_test.md`. +- `docs/tutorial_ml_experiment.md` — new end-to-end walkthrough of an ML/ONNX + fault injection experiment covering the full ONNX → LLVM IR compilation + pipeline, `CustomTensorOperator` layer targeting, multi-fault injection + options, per-layer profiling output, and `CompareLayerOutputs.py`. + +--- + +### Changed + +#### LLVM 20 API compatibility + +| File | Change | +|------|--------| +| `llvm_passes/core/FaultInjectionPass.cpp` | 3 sites: `new AllocaInst/StoreInst/LoadInst` constructors updated to LLVM 17+ API | +| `llvm_passes/core/InstTracePass.cpp` | 6 sites: same; `getFirstNonPHIOrDbgOrLifetime()` now returns `BasicBlock::iterator` | +| `llvm_passes/core/Utils.cpp` | `M.getGlobalList().push_back()` → `new GlobalVariable(M, ...)` (removed in LLVM 17) | +| All selector `.cpp` files | `getNumArgOperands()` → `arg_size()` (removed in LLVM 15); `#include "llvm/Support/CFG.h"` → `"llvm/IR/CFG.h"` | +| `llvm_passes/instruction_duplication/InstructionDuplication.cpp` | `getNextNonDebugInstruction()` return type updated to `BasicBlock::iterator` | + +#### Code quality (C++ — found by clang-tidy and code review) + +| Category | Details | +|----------|---------| +| Bug fixes | Double-free in `Controller.cpp` destructor; file stream leak in `LLFIDotGraphPass.cpp`; unchecked `fopen` null in `GenLLFIIndexPass.cpp`; uninitialized `isChainDuplication` field | +| Null safety | `getCalledFunction()` null checks in `ProfilingPass.cpp`, `InstructionDuplication.cpp`, `CustomTensorOperatorInstSelector.cpp` | +| LLVM idioms | `dyn_cast<>` after `isa<>` → `cast<>` (asserting) across `Utils.cpp`, `ProfilingPass.cpp`; `NULL` → `nullptr` throughout | +| Override safety | `virtual` on override methods → `override` keyword across all selector classes; `virtual ~Base() = default` added to abstract base classes | +| Style | `.empty()` over `.size() == 0`; `const auto&` in range-for; `strncpy`/`strncat` over unbounded `strcpy`/`strcat`; `cl::opt::getValue()` to avoid slicing | +| Dead code | Removed unreachable `return false` after exhaustive if/else in `InstructionDuplication.cpp:runOnMainGraph()` | +| Copies | `for (auto insVector : arithInst)` → `for (const auto& insVector : ...)` to avoid copying inner vectors | + +#### Code quality (Python — found by flake8/bugbear) + +- `except:` → `except Exception:` throughout `bin/`, `tools/`, `test_suite/SCRIPTS/` +- Bare `open()` → `with open(...) as f:` in multiple scripts +- `subprocess(..., shell=True)` removed; replaced with list-form calls +- `yaml.load()` → `yaml.safe_load()` everywhere +- `exit()` → `sys.exit()` in scripts +- `%-format` strings → f-strings in new code + +#### Docker + +- `docker/Dockerfile` — LLVM source checkout updated from a pinned LLVM 15 + commit hash (`9778ec057cf4`) to the `llvmorg-20.1.0` tag; `pyyaml===5.4.1` + corrected to `pyyaml==5.4.1` (non-standard triple-equals syntax). + +#### Documentation updates + +- `README.md` — restructured to eliminate overlap with `architecture.md`; + added `docs/` section listing all user guides; added pointer to + `architecture.md` for internal design. +- `caveats.txt` — LLVM version references updated 15 → 20; duplicate item + number fixed. +- `llvm_passes/instruction_duplication/README.md` — `opt -always-inline` + (legacy PM) → `opt --passes=always-inline` (new PM). +- `llvm_passes/instruction_duplication/shared_lib/build.sh` and + `compile_shrd_lib.sh` — hardcoded `clang`/`clang++` → `LLVM_GXX_BIN_DIR` + pattern, fixing builds on Ubuntu where apt installs `clang-20` only. +- `architecture.md` — corrected several inaccuracies found during code review: + `preFunc` return type (`bool` not `int`) and parameter types (`unsigned` + throughout); `injectFunc` register parameter types; `doProfiling` parameter + type (`int` not `unsigned`); `printInstTracer` signature (second param is + `char *opcode`, not `unsigned`; last param is `int`, not `long`); + `lltfiMLLayer` parameter types (`int64_t`); removed non-existent `random` + and `data_corruption` `fi_type` entries; corrected claim that + +--- + +### Pending before merge (human tasks) + +- **H-2** — Human review of IRBuilder insertion-point correctness in + `FaultInjectionPass.cpp` and `InstTracePass.cpp`. The `AllocaInst` calls + were migrated to `BasicBlock*` insertAtEnd form and `BasicBlock::iterator` + form respectively; both compile and all 21 tests pass, but a developer + familiar with the pass semantics should verify the insertion points are + logically correct before merging to main. +- **H-3** — onnx-mlir real-model validation. Requires installing onnx-mlir and + running `sample_programs/ml_sample_programs/vision_models/mnist/compile.sh` + to produce `model.ll`. The two new `test_instruction_duplication.py` tests + will then run instead of skipping. Not a blocker — all other tests pass. + +--- + +## [Previous] — LLVM 15 baseline (`master`) + +The `master` branch represents LLTFI as it existed targeting LLVM 15, with +the following improvements over the original LLFI fork: + +- ML fault injection support (TensorFlow, PyTorch via ONNX-MLIR) +- `CustomTensorOperator` instruction selector for layer-level ML targeting +- `InstructionDuplication` pass (`SEDPasses.so`) for soft-error detection +- Batch fault injection scripts (`batchInstrument.py`, `batchProfile.py`, + `batchInjectfault.py`) +- Trace analysis tools (`tracediff.py`, `traceontograph.py`, `traceunion.py`, + `tracetodot.py`) +- Makefile generation tool (`GenerateMakefile`) +- Initial `CODING_GUIDELINES.md` and `CONTRIBUTING.md` +- Migration plan document (`migration.md`) for the LLVM 15 → 20 upgrade diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..51166adc --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,176 @@ +# LLTFI — Claude Context + +## What this project is + +LLTFI (Low-Level Tensor Fault Injector) is an LLVM-based fault injection framework supporting C/C++ and ML applications (TensorFlow, PyTorch). It injects faults into LLVM IR. It is built on top of LLFI and is fully backward compatible with it. + +--- + +## Build setup (this machine) + +| Variable | Path | +|----------|------| +| Source tree | `/home/karthik/Programs/LLTFI` | +| Build root | `/home/karthik/Programs/LLTFI-build` | +| LLVM DST root | `/usr/lib/llvm-20` (apt install) | +| LLVM SRC root | `/home/karthik/Programs/llvm-project` | +| LLVM version | 20.1 | + +To rebuild from source: +```bash +./setup -LLFI_BUILD_ROOT /home/karthik/Programs/LLTFI-build \ + -LLVM_SRC_ROOT /home/karthik/Programs/llvm-project \ + -LLVM_DST_ROOT /usr/lib/llvm-20 \ + -LLVM_GXX_BIN_DIR /usr/lib/llvm-20/bin +``` + +The build root must not already exist. Delete it first if rebuilding from scratch. + +To rebuild after code changes (faster): +```bash +cd /home/karthik/Programs/LLTFI-build && make +``` + +--- + +## Running the test suite + +From the **build** directory: +```bash +cd /home/karthik/Programs/LLTFI-build/test_suite +python3 SCRIPTS/llfi_test --all_cpp # all 21 tests +python3 SCRIPTS/llfi_test --all_hardware_faults # 8 tests +python3 SCRIPTS/llfi_test --all_trace_tools_tests # 3 tests +python3 SCRIPTS/llfi_test --all_makefile_generation # 2 tests +python3 SCRIPTS/llfi_test --all_ml # ML/ONNX tools +``` + +Expected: **21/21 PASS** for `--all`. Some error messages during fault injection runs are normal. + +`--all_ml` runs additional tests not included in `--all`: + +| Test group | Tests | Requirements | +|------------|-------|--------------| +| `CompareLayerOutputs` | 2 | `pip install onnx pygraphviz` | +| `ExtendONNXModel` | 1 | `pip install onnx` | +| `outputONNXGraph` | 1 | `pip install onnx pydot` | +| TensorFlow → ONNX | 3 | `pip install tensorflow tf2onnx onnx` | +| PyTorch → ONNX | 2 | `pip install torch onnx` | +| ONNX → LLVM IR | 2 | `onnx-mlir` + `mlir-translate` on PATH or `$ONNX_MLIR_BUILD` set | +| Fault injection (ML) | 3 | LLTFI build + `model.ll` in `sample_programs/.../mnist/` (run `compile.sh` first) | + +Tests with missing deps are reported as **SKIP** (not FAIL) and excluded from the pass/fail count. + +The ONNX→IR and fault injection tests use the pre-built `model.onnx` from +`sample_programs/ml_sample_programs/vision_models/mnist/`. The fault injection +test also requires `model.ll` which is produced by that directory's `compile.sh`. + +--- + +## Key architecture + +``` +llvm_passes/ LLVM pass plugins (compiled to llfi-passes.so) + core/ Fault injection, profiling, tracing passes + hardware_failures/ Built-in hardware fault selectors (bitflip, funcname, etc.) +runtime_lib/ Runtime library (libllfi-rt.so) linked into instrumented binaries +bin/ Python driver scripts: instrument.py, profile.py, injectfault.py +docs/ input_masterlist.yaml, input_masterlist_ml.yaml — reference schemas + for the input.yaml files that control instrumentation and injection + input_yaml_guide.md — prose guide to writing input.yaml (user-facing) +tools/ Trace analysis tools (tracediff.py, traceontograph.py, traceunion.py, + tracetodot.py), GenerateMakefile/ +test_suite/ Regression tests + PROGRAMS/ Source programs used by tests + HardwareFaults/ Hardware fault injection test cases + Traces/ Pre-committed trace reference files for trace tool tests + MakefileGeneration/ Makefile generation test cases +``` + +--- + +## LLVM 19 API constraints + +The codebase targets **LLVM 19**. Key API changes to keep in mind: + +- `#include "llvm/IR/CFG.h"` (not `llvm/Support/CFG.h` — removed in LLVM 15) +- `CI->arg_size()` (not `CI->getNumArgOperands()` — removed in LLVM 15) +- `func->getName().str()` returns `StringRef`; call `.str()` before assigning to `std::string` +- `opt -load-pass-plugin` (not `opt -load` — legacy PM removed in LLVM 17) +- `--passes=passname` (not `-passname` — old opt syntax removed in LLVM 17) +- `InsertPosition` requires a `BasicBlock::iterator`, not a raw `Instruction*` — call `.getIterator()` on the insertion point +- `getFirstNonPHIOrDbgOrLifetime()` now returns `BasicBlock::iterator` (not `Instruction*`) +- `getFirstNonPHI()` → `getFirstNonPHIIt()` (returns iterator) +- `M.getGlobalList()` is private — use `new GlobalVariable(M, type, ...)` to insert directly +- `itaniumDemangle(str)` takes a single `string_view` argument (old 4-arg form removed) +- All passes use the **new pass manager** (`PassInfoMixin`, `llvmGetPassPluginInfo`); `InstructionDuplication` exposes both a legacy PM class and a new PM wrapper (`NewInstructionDuplicationPass`) registered as `"InstructionDuplicationPass"` in `SEDPasses.so` + +--- + +## Test suite — untracked files that should not be committed + +The following files appear as untracked after running tests and should not be staged: +- `test_suite/HardwareFaults/*/inp.in`, `graph_input.dat` — deployed by `deploy_prog.py` from `PROGRAMS/` +- `test_suite/HardwareFaults/*/llfi.test.log.*` — test run artifacts +- `test_suite/Traces/*/llfi/llfi_stat_output/*.report.txt` — generated by trace analysis + +--- + +## Linting + +Run from the source tree root: + +```bash +bash lint.sh # check only +bash lint.sh --fix # auto-fix clang-format issues in-place +bash lint.sh --cpp # C++ checks only +bash lint.sh --python # Python checks only +``` + +**Requirements:** +- C++: `clang-format-20` and `clang-tidy-20` (`apt install clang-format-20 clang-tidy-20`) +- C++ static analysis also needs `compile_commands.json` in the build root: + ```bash + cd /home/karthik/Programs/LLTFI-build && cmake -DCMAKE_EXPORT_COMPILE_COMMANDS=ON . + ``` +- Python: `flake8` and `flake8-bugbear` (`pip install flake8 flake8-bugbear`) + +**clang-tidy config (`.clang-tidy`):** Checks include `modernize-use-override`, `readability-container-size-empty`, `cppcoreguidelines-init-variables`, `bugprone-*`, `clang-analyzer-core.*`, `performance-*`, and others. The following are intentionally disabled because they fire on legitimate LLVM patterns or system-header code: + +| Disabled check | Reason | +|---|---| +| `cppcoreguidelines-slicing` | `cl::opt` to `string` is idiomatic LLVM (use `.getValue()` instead) | +| `clang-analyzer-optin.*` | Fires on standard LLVM pass framework patterns | +| `clang-analyzer-cplusplus.NewDelete` | False positives from LLVM's internal memory management | +| `clang-diagnostic-macro-redefined` | Suppress `DEBUG_TYPE` conflicts with LLVM headers | +| `bugprone-assignment-in-if-condition` | `while ((pos = s.find(x)) != npos)` is idiomatic C++ | + +--- + +## Code style + +See `CODING_GUIDELINES.md` for the full style guide. Key points: + +**C++:** +- Use `nullptr`, not `NULL` +- Every header needs `#ifndef` include guards +- Missing `return` in `bool runOnModule(...)` is UB — always return `false` unless IR was modified +- Derived-class overrides: use `override`, omit `virtual`; abstract base classes need `virtual ~Base() = default;` +- Use `cast<>` (asserts) when type is guaranteed; `dyn_cast<>` (returns null) when it may not match +- Use `except Exception:` in Python; in C++ use `errs()` for pass diagnostics + +**Python:** +- Python 3 only (`#!/usr/bin/env python3`) +- Use `with open(...) as f:` always — no bare `open()` without context manager +- `except Exception:` minimum; use specific types (`OSError`, `KeyError`) where known +- `sys.exit()`, not `exit()` +- No `subprocess(..., shell=True)` — use list args and `stdout=open(file, 'w')` for redirection +- `yaml.safe_load()` always — never `yaml.load()` without Loader + +--- + +## Wiki known inconsistencies (external, not fixable here) + +- Wiki Page 4: says `input_masterlist.yaml` is in `bin/` — actually in `docs/` +- Wiki Page 5: uses `llvm-gcc` (replaced by clang) and old `opt -load` syntax +- Wiki Page 9: references `gui/config/` files that no longer exist (Java GUI removed) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1ec89e9c..a09d7e5a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.15) +cmake_minimum_required(VERSION 3.5) option(NO_GUI "Skip building GUI" OFF) diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md new file mode 100644 index 00000000..252500ef --- /dev/null +++ b/CODING_GUIDELINES.md @@ -0,0 +1,374 @@ +# LLTFI Coding Guidelines + +These guidelines apply to all hand-written source files in the project. C++ and C sources live under `llvm_passes/`, `runtime_lib/`, and `tools/`. Python sources live under `bin/`, `test_suite/SCRIPTS/`, `tools/`, and `setup`. + +--- + +## C++ Standard + +Target **C++17**. LLVM 20 requires it, and it provides range-based for loops, `nullptr`, `auto`, structured bindings, and other improvements used throughout the codebase. + +--- + +## Naming Conventions + +| Construct | Style | Example | +|-----------|-------|---------| +| Classes / structs | `PascalCase` | `FaultInjectionPass`, `FIInstSelector` | +| Functions / methods | `camelCase` | `getFIInsts()`, `isInstFITarget()` | +| Local variables | `snake_case` | `fi_inst`, `reg_pos_list` | +| Member variables | `snake_case` (no prefix) | `fi_rettype_funcname_map` | +| Constants / macros | `ALL_CAPS` | `DST_REG_POS`, `OPTION_LENGTH` | +| Namespaces | `lowercase` | `namespace llfi` | + +--- + +## Null Pointers + +Use `nullptr` in all C++ code. Never use `NULL` or `0` for pointer comparisons in C++ files. + +```cpp +// Good +Value *fi_reg = nullptr; +if (mainfunc == nullptr) { ... } + +// Bad +Value *fi_reg = NULL; +if (mainfunc == 0) { ... } +``` + +`NULL` remains acceptable in C files (`.c`). + +--- + +## Include Guards + +Every header file must have an include guard. Use `#ifndef` / `#define` / `#endif`. The guard name must match the filename (upper-cased, dots and slashes replaced with underscores). + +```cpp +// FaultInjectionPass.h +#ifndef FAULT_INJECTION_PASS_H +#define FAULT_INJECTION_PASS_H + +// ... contents ... + +#endif // FAULT_INJECTION_PASS_H +``` + +Do **not** use `#pragma once` — it is not part of the C++ standard and is inconsistent with the rest of the codebase. + +--- + +## Memory Management + +Prefer stack allocation or LLVM's ownership model over manual `new`/`delete`. + +```cpp +// Good — stack allocation +LegacyProfilingPass pass; +pass.runOnModule(M); + +// Also good — LLVM IR objects are owned by the module +new AllocaInst(type, addrSpace, "", insertBefore); // IR takes ownership + +// Bad — manual heap allocation for short-lived objects +auto *obj = new LegacyProfilingPass(); +obj->runOnModule(M); +delete obj; +``` + +When heap allocation is unavoidable outside of LLVM IR insertion, use `std::unique_ptr`. + +--- + +## Return Statements + +Every non-void function must have a return statement on all code paths. Missing returns in functions returning `bool` (e.g., `runOnModule`) are undefined behaviour and generate compiler warnings. + +```cpp +// Good +bool runOnModule(Module &M) override { + // ... do work ... + return false; // ModulePass: return true only if IR was modified +} +``` + +--- + +## Virtual Methods and `override` + +In derived classes, always mark overriding methods with `override` and omit the redundant `virtual` keyword. This catches signature mismatches at compile time and makes the override relationship explicit. + +```cpp +// Good — override makes the intent clear and catches errors +class MyInstSelector : public HardwareFIInstSelector { + bool isInstFITarget(Instruction *inst) override; + void getCompileTimeInfo(std::map &info) override; +}; + +// Bad — virtual is redundant; override is missing +class MyInstSelector : public HardwareFIInstSelector { + virtual bool isInstFITarget(Instruction *inst); + virtual void getCompileTimeInfo(std::map &info); +}; +``` + +Any abstract base class that may be deleted through a base pointer **must** declare a virtual destructor: + +```cpp +class FIInstSelector { +public: + virtual ~FIInstSelector() = default; // Required — subclass instances may be deleted via base ptr + virtual bool isInstFITarget(Instruction *inst) = 0; + ... +}; +``` + +--- + +## Variable Initialisation + +Initialise all local variables and pointers at the point of declaration. Uninitialised variables are undefined behaviour and are caught by the static analyser. + +```cpp +// Good +Instruction *insertPoint = nullptr; +float bitSize = 0.0f; +std::map *> *regs_map = nullptr; + +// Bad — UB if the assignment is missed on any code path +Instruction *insertPoint; +float bitSize; +``` + +--- + +## Container Emptiness + +Use `.empty()` to test whether a container has elements. Do not compare `.size()` to zero — it is less readable and some container types compute `size()` in O(n). + +```cpp +// Good +assert(!exitinsts.empty() && "Program has no exit point"); +if (reglist->empty()) return; + +// Bad +assert(exitinsts.size() != 0 && "Program has no exit point"); +if (reglist->size() == 0) return; +``` + +--- + +## LLVM Idioms + +Use LLVM's type-checking utilities instead of C-style casts: + +```cpp +// Good — dyn_cast when the type may not match; always check for null +CallInst *CI = dyn_cast(inst); +if (CI) { ... } + +// Good — cast<> when the type is guaranteed (e.g. after an opcode or isa<> check); +// it asserts on failure rather than returning null +if (inst->getOpcode() == Instruction::Call) { + CallInst *CI = cast(inst); // safe: opcode already verified + ... +} + +// Bad +if (dynamic_cast(inst)) { ... } +CallInst *CI = (CallInst*)inst; +``` + +Use `llvm::errs()` for diagnostic output from passes, not `std::cerr`: + +```cpp +// Good (in LLVM passes) +errs() << "ERROR: ...\n"; + +// Acceptable (in runtime C library) +fprintf(stderr, "ERROR: ...\n"); +``` + +Use `StringRef` and `.str()` when converting LLVM names to `std::string`: + +```cpp +std::string name = called_func->getName().str(); // Good +std::string name = called_func->getName(); // Bad: implicit conversion removed in LLVM 15 +``` + +--- + +## Error Handling + +- In LLVM passes: print a message with `errs()` and return early. Do not call `exit()`. +- In the runtime C library: `exit(1)` is acceptable for unrecoverable configuration errors (e.g., missing config file at startup), but not for conditions that can be recovered from. +- Always check the return value of `fopen` and similar system calls before using the result. + +```cpp +// Good +FILE *f = fopen(filename, "r"); +if (f == NULL) { + fprintf(stderr, "ERROR: cannot open %s\n", filename); + exit(1); +} +``` + +--- + +## Brace Style + +Opening brace on the **same line** as the declaration (K&R / LLVM style): + +```cpp +// Good +void FaultInjectionPass::insertInjectionFuncCall(...) { + for (...) { + if (...) { + } + } +} + +// Bad (Allman style — don't use) +void FaultInjectionPass::insertInjectionFuncCall(...) +{ +} +``` + +--- + +## Indentation + +Use **2 spaces**. Do not use tabs. Most editors can be configured to convert tabs to spaces automatically. + +--- + +## `using namespace` in Headers + +Avoid `using namespace llvm;` in header files. It pollutes the namespace of every translation unit that includes the header. Existing headers use it for historical reasons; do not add it to new headers. + +In `.cpp` files, `using namespace llvm;` is acceptable. + +--- + +## Comments + +- Use `//` for single-line and multi-line inline comments. +- Use the LLVM file banner (see `FaultInjectionPass.cpp`) for files that are part of the LLFI/LLTFI distribution. +- Remove contributor-specific annotations (e.g. `/*BEHROOZ: ...*/`, `//=== QINING @DATE ===`) before merging. Instead, document the *why* in a neutral comment or a commit message. +- Avoid commented-out code. Use version control instead. + +--- + +## Python Guidelines + +### Python Version + +All scripts must target **Python 3**. Use `#!/usr/bin/env python3` as the shebang. Do not use Python 2 syntax or shebangs. + +### Naming Conventions + +Follow **PEP 8** throughout: + +| Construct | Style | Example | +|-----------|-------|---------| +| Functions / methods | `snake_case` | `check_input_yaml()`, `parse_args()` | +| Variables | `snake_case` | `option_list`, `base_dir` | +| Classes | `PascalCase` | `DiffBlock`, `FaultReport` | +| Constants (module-level) | `UPPER_CASE` | `DEFAULT_TIMEOUT` | + +### Exception Handling + +Never use a bare `except:` — it silently swallows `KeyboardInterrupt` and `SystemExit`. Use the most specific exception type available; fall back to `except Exception:` only when the exact type is genuinely unknown. + +```python +# Good — specific +try: + with open(path, 'r') as f: + doc = yaml.safe_load(f) +except OSError: + print("ERROR: cannot open", path) + sys.exit(1) +except yaml.YAMLError: + print("ERROR: invalid YAML in", path) + sys.exit(1) + +# Acceptable fallback when type is unclear +except Exception: + ... + +# Bad — catches everything including Ctrl-C +except: + ... +``` + +### File Handling + +Always use `with` statements (context managers) for file operations. This guarantees the file is closed even if an exception is raised. + +```python +# Good +with open(filename, 'r') as f: + data = f.read() + +# Bad — file may not be closed on exception +f = open(filename, 'r') +data = f.read() +f.close() +``` + +### subprocess — Avoid `shell=True` + +Do not pass `shell=True` to `subprocess` functions unless strictly necessary. Use a list of arguments instead. For output redirection (which requires a shell when using `>`), pass `stdout=open(output_file, 'w')` to subprocess directly. + +```python +# Good +with open(output_file, 'w') as out: + subprocess.call([script, arg1, arg2], stdout=out, stderr=log) + +# Bad — shell injection risk, path-escaping burden +cmd = script + " " + arg1 + " > " + output_file +subprocess.call(cmd, shell=True) +``` + +### `sys.exit()` vs `exit()` + +Use `sys.exit()` in scripts. `exit()` is intended for the interactive interpreter and is not guaranteed to be available in all environments. + +```python +import sys +sys.exit(1) # Good +exit(1) # Bad +``` + +### String Formatting + +Prefer **f-strings** for new code (Python 3.6+). Avoid `%`-formatting in new code. + +```python +# Good +print(f"ERROR: No input.yaml in {srcpath}") + +# Acceptable for existing code +print("ERROR: No input.yaml in {}".format(srcpath)) + +# Avoid in new code +print("ERROR: No input.yaml in %s" % srcpath) +``` + +### Imports + +One import per line. Order: standard library → third-party → local. Separate each group with a blank line. + +```python +import os +import sys + +import yaml + +from config import llvm_paths +``` + +### `yaml.safe_load()` + +Always use `yaml.safe_load()`. Never use `yaml.load()` without an explicit `Loader` argument — it executes arbitrary Python and is a security risk. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..262fe1d7 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,89 @@ +# Contributing to LLTFI + +## Prerequisites + +Before writing any code, read: +- [CODING_GUIDELINES.md](CODING_GUIDELINES.md) — C++ and Python style rules +- [CLAUDE.md](CLAUDE.md) — project architecture, build setup, constraints +- [caveats.txt](caveats.txt) — known limitations and gotchas + +## Development Setup + +```bash +# Run from the repo root. The setup script runs cmake and make internally. +./setup -LLFI_BUILD_ROOT /path/to/LLTFI-build \ + -LLVM_SRC_ROOT /path/to/llvm-project \ + -LLVM_DST_ROOT /usr/lib/llvm-20 \ + -LLVM_GXX_BIN_DIR /usr/lib/llvm-20/bin +``` + +After code changes, rebuild without re-running setup: +```bash +cd /path/to/LLTFI-build && make +``` + +## Making Changes + +### C++ Changes + +- Target C++17; compile against LLVM 20 APIs +- Use `nullptr`, `#ifndef` include guards, RAII memory management +- LLVM 20 API: `arg_size()` (not `getNumArgOperands()`), `#include "llvm/IR/CFG.h"` (not `llvm/Support/CFG.h`), `InsertPosition` via `.getIterator()` for instruction insertion points, `--passes=name` (not `-name`) for opt invocation +- `runOnModule` must return `bool` + +### Python Changes + +- Python 3 only; follow PEP 8 +- Always use `with open(...)` for file I/O +- Never use `shell=True` in subprocess calls +- Use `sys.exit()`, not `exit()` +- Use `yaml.safe_load()`, not `yaml.load()` +- Minimum `except Exception:` — never bare `except:` + +## Adding a Test Case + +See `docs/adding_a_test.md` for a step-by-step guide to creating a new +regression test, including how to register a program, structure the test case +directory, and use the SKIP convention for optional dependencies. + +## Running Tests + +From the build directory: + +```bash +# All core tests (must pass before any PR) +./test_suite/SCRIPTS/llfi_test --all_cpp + +# Specific subsets +./test_suite/SCRIPTS/llfi_test --all_hardware_faults +./test_suite/SCRIPTS/llfi_test --all_trace_tools_tests + +# Single test case +./test_suite/SCRIPTS/llfi_test --test_cases +``` + +### ML/ONNX tests + +```bash +./test_suite/SCRIPTS/llfi_test --all_ml +``` + +These are not part of `--all` because they require optional dependencies. +Tests with missing deps report **SKIP**, not FAIL, and don't affect the pass/fail count. + +| Group | Requirements | +|-------|-------------| +| ML tool unit tests | `pip install onnx pygraphviz pydot` | +| TensorFlow → ONNX | `pip install tensorflow tf2onnx onnx` | +| PyTorch → ONNX | `pip install torch onnx` | +| ONNX → LLVM IR | onnx-mlir binary (`ONNX_MLIR_BUILD` env var) | +| Fault injection (ML) | LLTFI build + `model.ll` from `sample_programs/.../mnist/compile.sh` | + +All core tests (`--all_cpp`) must pass before submitting a pull request. ML tests (`--all_ml`) should pass for any change that touches ML-related code. + +## Pull Request Checklist + +- [ ] All tests pass (`llfi_test --all_cpp`) +- [ ] Code follows CODING_GUIDELINES.md +- [ ] If changing a public API or behavior: README.md is updated +- [ ] New functionality has a corresponding test case diff --git a/CREDITS.TXT b/CREDITS.TXT index 6f4a1095..e3658a03 100755 --- a/CREDITS.TXT +++ b/CREDITS.TXT @@ -45,7 +45,7 @@ D: Added Software Fault Injection to the GUI and integrated FIDL into LLFI. N: Abraham Chan D: Added Makefile script to support compilation of entire projects. -D: Upgraded LLVM from 3.4 to LLVM 12, added initial capability for FI on ML models +D: Upgraded LLVM from 3.4 to LLVM 12 to LLVM 19, added initial capability for FI on ML models N: Behrooz Sangchoolie D: Added support for multiple bit fault injections for hardware faults. diff --git a/LLFIWriteup.pdf b/LLFIWriteup.pdf deleted file mode 100755 index 35ca78b9..00000000 Binary files a/LLFIWriteup.pdf and /dev/null differ diff --git a/README.TXT b/README.TXT deleted file mode 100755 index a05ca737..00000000 --- a/README.TXT +++ /dev/null @@ -1,49 +0,0 @@ -LLVM Fault Injector - LLFI - Description: An LLVM Tool for fault injection, easily map between fault at IR and source level, configurable and extensible. - -====== -Pre-requisites - 1. CMake installed - 2. LLVM version 3.4, built with CMake - 3. Python 3 - 4. Python YAML library installed (PyYAML) - 5. clang-3.4 ( frontend for llvm 3.4 ) - 6. 64 bit Machines with 64 bit Linux/OS X. - -====== -Installation - A. Install CMake, Python, PyYAML library - - B. Install llvm-3.4 and clang 3.4 - 1. Go to "http://llvm.org/releases/download.html#3.4" to download LLVM source code and clang source code/binaries for your system. - 2. If building clang from source code, copy the source code under tools/. Access "http://llvm.org/releases/3.4/docs/GettingStarted.html#installcf" for instructions. - 2. Build llvm-3.4 ***WITH CMAKE*** 'using flag -DLLVM_REQUIRES_RTTI=1'. Access "http://llvm.org/docs/CMake.html" for instructions. - - C. Build LLFI - 1. Extract the code from LLFI archive (/LLFI) - 2. Go to /LLFI directory and run './setup --help' to see how to build LLFI to a different directory - - D. Testing LLFI - You can use example programs in /LLFI/test_suite/PROGRAMS/factorial to test LLFI. - - Example program: factorial - 1. Copy test_suite/factorial/ to your project directory and change to that directory. - 2. Build a single IR file with LLFI tool compiletoIR - /tools/GenerateMakefile --readable -o factorial.ll --all - 3. Instrument factorial with calls to LLFI libraries and create executables under llfi/ - /bin/instrument --readable factorial.ll - 4. Run factorial executable with profiling functions instrumented - /bin/profile llfi/factorial-profiling.exe 6 - In file llfi/baseline/golden_std_output, you should be able to see 720 - 5. Run factorial executable with fault injection functions instrumented - /bin/injectfault llfi/factorial-faultinjection.exe 6 - You should be able to see result files in llfi/std_output/, fault injection stats in llfi/prog_output/, failure report (crash/hang) in llfi/error_output/ - - For complete test of whole of LLFI, please use LLFI test suite and refer to wiki page: 'https://github.com/DependableSystemsLab/LLFI/wiki/Test-Suite-for-Regression-Test' for details. - -====== -Running LLFI on your target applications - You can follow the same flow as the Step D of Installation (Testing LLFI). For more details, you can follow the instructions on https://github.com/DependableSystemsLab/LLFI/wiki. - -====== -Read caveats.txt for caveats and known problems. diff --git a/README.md b/README.md index e8242e0e..ff99efe5 100755 --- a/README.md +++ b/README.md @@ -1,195 +1,177 @@ LLTFI ===== -LLTFI (Low-Level Tensor Fault Injector) is a unified SWiFI (Software-implemented fault injection) tool that supports fault injection of both C/C++ programs and ML applications written using high-level frameworks such as TensorFlow and PyTorch. - -As machine learning (ML) has become more prevalent across many critical domains, so has the need to understand ML system resilience. While there are many ML fault injectors at the application level, there has been little work enabling fault injection of ML applications at a lower level. **LLTFI** is a tool that allows users to run fault injection experiments on C/C++, TensorFlow and PyTorch applications at a lower level (at the LLVM IR level). Please refer to the following [paper](https://blogs.ubc.ca/dependablesystemslab/2021/08/31/wip-lltfi-low-level-tensor-fault-injector/) for more information about LLTFI. - -LLTFI is built on top of [LLFI](https://github.com/DependableSystemsLab/LLFI) and is fully backward compatible with it. - -### LLFI ### -**LLFI** is an LLVM-based fault injection tool, that injects faults into the LLVM IR of the application source code. The faults can be injected into specific program points, and the effect can be easily tracked back to the source code. LLFI is typically used to map fault characteristics back to source code and hence understand source level or program characteristics for various kinds of fault outcomes. Detailed documentation about LLFI can be found at: https://github.com/DependableSystemsLab/LLFI/wiki. Because LLTFI is designed to be backward compatible with LLFI, the basic setup instructions for LLTFI are similar to those of LLFI. However, there are additional steps and dependencies for running ML programs. - -LLTFI Workflow: -------------------------- -High-level ML models need to be lowered to intermediate representation (IR) for fault injection. LLTFI provides a single script that converts ML models into LLVM IR, using several publicly available tools and performs fault injection. -LLTFI first lowers ML models to **MLIR** (Multi-Level Intermediate Representation) using ONNX-MLIR before converting to LLVM IR. The reasons for choosing MLIR are MLIR's ability to better preserve the semantics of ML models, its integration with LLVM, testability and easier extensibility. - -#### Workflow Diagram of LLTFI: #### - -![Alt text](images/workflow.png?raw=true "Workflow Diagram of LLTFI") - -- LLTFI first converts all ML models to the ONNX format. ONNX’s open exchange format allows LLTFI to -support both TensorFlow and PyTorch. -- Then, the ONNX file is converted into MLIR through ONNX-MLIR. -- Finally, we convert MLIR into LLVM IR, using the mlir-translate tool in LLVM 15.0. - -**LLTFI** can now inject faults into the LLVM IR, alike lowered C/C++ programs. - -The LLFI tool was originally written for LLVM 3.4. While developing LLTFI, the entire LLFI tool was upgraded to LLVM 15.0 because LLVM 3.4 has no support for MLIR. -This upgrade also ensured that LLTFI is compatible with all of the newest C/C++ features, and LLVM optimization passes - - -Auto-Installer --------------- -If you wish to build LLTFI and its dependencies via the auto-installer(installer/InstallLLTFI.py), you *do not need* to clone the LLTFI git repository. Simply download the installer script by itself, and it will fetch the latest version of the git repository for you. To run the script, simply copy it into the directory where you would like to build the LLTFI and, from the command line, run `python3 InstallLLTFI.py`. - -Dependencies: - 1. 64 Bit Machine (preferably with GPU for faster training of ML programs) - 2. 64 bit Linux (Ubuntu 20.04) or OS X - 3. CMake (minimum v3.15) - 4. Python 3 and above - 5. Ninja >= 1.10.2 - 6. Internet Connection - -Usage: - 1. Copy the InstallLLTFI.py script to where you want to build the LLTFI. Run "python3 InstallLLTFI.py -h" to see all running options/guidelines - 2. Run "python3 InstallLLTFI.py" - - -Manual Installation -------------------- - -In this method, the developer has more control over the location of the LLVM build that the LLTFI requires. If you already have LLVM built, you could use that build. - -**Dependencies:** - - 1. 64 Bit Machine (preferably with GPU for faster training of ML programs) - 2. 64 bit Linux (Ubuntu 20.04) or OS X - 3. CMake (minimum v3.15) - 4. Python 3 and above - 5. Python YAML library (PyYAML v5.4.1 or higher, v6.0+ supported) - 6. Ninja >= 1.10.2 - 7. libprotoc >= 3.11.0 - 8. Clang v15.0 (commit: 9778ec057cf4) - 9. LLVM v15.0 (commit: 9778ec057cf4) ([Reference](http://llvm.org/docs/CMake.html)). - LLVM 15.0 takes a long time to completely build. Following is a shortcut to checking out the required LLVM commit, and building only the necessary LLVM targets. - ``` - git clone https://github.com/llvm/llvm-project.git - - # Check out a specific branch that is known to work with the required version of ONNX MLIR. - cd llvm-project && git checkout 9778ec057cf4 && cd .. - - mkdir llvm-project/build - cd llvm-project/build - - cmake -G Ninja ../llvm \ - -DLLVM_ENABLE_PROJECTS="clang;mlir" \ - -DLLVM_BUILD_TESTS=ON \ - -DLLVM_TARGETS_TO_BUILD="host" \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DLLVM_ENABLE_RTTI=ON - - cmake --build . --target clang check-mlir mlir-translate opt llc lli llvm-dis llvm-link -j 2 - - ninja install -j 2 - ``` - 10. For executing ML programs, following additional dependencies have to be installed: - 1. TensorFlow framework (v2.0 or greater) - 2. numpy package (part of TensorFlow) - 3. [tensorflow-onnx](https://github.com/onnx/tensorflow-onnx): - Installation with pip is sufficient - ``` - pip install tf2onnx - ``` - 4. libprotoc - ``` - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.17.2/protobuf-all-3.17.2.zip - unzip protobuf-all-3.17.2.zip - cd protobuf-3.17.2 - - ./configure - make -j 2 - make check - sudo make install - sudo ldconfig # refresh shared library cache. - ``` - 5. [ONNX-MLIR](https://github.com/onnx/onnx-mlir) - - Additional changes made in the ONNX-MLIR code are present in: https://github.com/DependableSystemsLab/onnx-mlir-lltfi. Clone this repo and checkout the `LLTFI` branch. The MLIR_DIR cmake variable must be set before building onnx-mlir. It should point to the mlir cmake module inside an llvm-project build or install directory (e.g., llvm-project/build/lib/cmake/mlir). - ``` - MLIR_DIR=$(pwd)/llvm-project/build/lib/cmake/mlir - ``` - - Onnx-mlir branch ``` LLTFI ``` has to be built and installed. - ``` - git clone --recursive https://github.com/DependableSystemsLab/onnx-mlir-lltfi.git - mv onnx-mlir-lltfi onnx-mlir && cd onnx-mlir - git checkout LLTFI - cd .. - - mkdir onnx-mlir/build && cd onnx-mlir/build - cmake -G Ninja \ - -DCMAKE_CXX_COMPILER=/usr/bin/c++ \ - -DMLIR_DIR=${MLIR_DIR} \ - .. - - cmake --build . - - # Run lit tests: - export LIT_OPTS=-v - cmake --build . --target check-onnx-lit - - ninja install - ``` - 11. GraphViz package (for visualizing error propagation) +LLTFI (Low-Level Tensor Fault Injector) is a unified SWiFI tool that supports +fault injection of both C/C++ programs and ML applications written using +high-level frameworks such as TensorFlow and PyTorch. Faults are injected at +the LLVM IR level, giving precise control over which instructions and registers +are targeted. +LLTFI is built on top of [LLFI](https://github.com/DependableSystemsLab/LLFI) +and is fully backward compatible with it. +For a detailed description of the internal architecture — pass pipeline, +selector class hierarchy, hardware/software/ML fault modes, runtime library, +and the interface between the compile-time and runtime layers — see +**[architecture.md](architecture.md)**. - +Please refer to the following +[paper](https://blogs.ubc.ca/dependablesystemslab/2021/08/31/wip-lltfi-low-level-tensor-fault-injector/) +for background on LLTFI. -### Building LLTFI: ### - - Run `./setup --help` for build instructions. -``` - $ ./setup --help - - Usage: setup OPTIONS - List of options: - -LLVM_DST_ROOT : - Make sure you build LLVM with CMake and pass build root directory here - -LLVM_SRC_ROOT - -LLFI_BUILD_ROOT - -LLVM_GXX_BIN_DIR (optional): - You don't need to set it if clang is already in system path +Repository Layout +----------------- - --help(-h): show help information - --runTests: Add this option if you want to run all regression tests after building LLTFI. ``` - - Below is the command to build LLTFI if `clang` is already in $PATH (replace paths with your actual directories): +llvm_passes/ LLVM pass plugin (llfi-passes.so) — compile-time only + core/ Pass infrastructure and selector framework + hardware_failures/ Built-in hardware fault instruction selectors + instruction_duplication/ SID pass for ML soft-error detection (SEDPasses.so) +runtime_lib/ C/C++ runtime library linked into instrumented binaries +bin/ Python driver scripts: instrument.py, profile.py, injectfault.py +tools/ Trace analysis, ML utilities + GenerateMakefile/ Test harness Makefile generator +docs/ tutorial_first_experiment.md — end-to-end C/C++ walkthrough and output guide + tutorial_ml_experiment.md — end-to-end ML/ONNX walkthrough (layer targeting, multi-fault) + adding_a_test.md — how to add a regression test case + input_yaml_guide.md — user guide for writing input.yaml + input_masterlist.yaml — full reference schema for input.yaml +test_suite/ Regression tests +sample_programs/ Example C/C++ and ML programs with input.yaml files +architecture.md Internal architecture reference for developers +CODING_GUIDELINES.md C++ and Python style rules +CONTRIBUTING.md How to set up a dev environment and submit changes +migration.md LLVM 15 → 19.x upgrade log ``` -./setup -LLFI_BUILD_ROOT /path/to/LLFI-build -LLVM_SRC_ROOT /path/to/llvm-project -LLVM_DST_ROOT /path/to/llvm-project/build + + +Dependencies +------------ + +1. 64-bit Linux (Ubuntu 20.04 or later) or macOS +2. CMake ≥ 3.15 +3. Python 3 +4. Python YAML library (PyYAML ≥ 5.4.1) +5. Ninja ≥ 1.10.2 +6. **Clang and LLVM b270525f730b** + + To build LLVM from source (required if you also need MLIR for onnx-mlir): + ```bash + git clone https://github.com/llvm/llvm-project.git + cd llvm-project && git checkout b270525f730b && cd .. + mkdir llvm-project/build && cd llvm-project/build + cmake -G Ninja ../llvm \ + -DLLVM_ENABLE_PROJECTS="clang;mlir" \ + -DLLVM_BUILD_TESTS=ON \ + -DLLVM_TARGETS_TO_BUILD="host" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON + cmake --build . --target clang check-mlir mlir-translate opt llc lli \ + llvm-dis llvm-link -j$(nproc) + ``` + +7. **For ML programs** (all optional; tests skip gracefully when absent): + + | Dependency | Install | + |------------|---------| + | TensorFlow ≥ 2.0 | `pip install tensorflow` | + | tensorflow-onnx | `pip install tf2onnx` | + | PyTorch | `pip install torch` | + | ONNX | `pip install onnx` | + | pygraphviz, pydot | `pip install pygraphviz pydot` | + | libprotoc ≥ 3.11 | build from source (see below) | + | [ONNX-MLIR](https://github.com/DependableSystemsLab/onnx-mlir-lltfi) (LLTFI branch) | see below | + + **libprotoc:** + ```bash + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.17.2/protobuf-all-3.17.2.zip + unzip protobuf-all-3.17.2.zip && cd protobuf-3.17.2 + ./configure && make -j$(nproc) && sudo make install && sudo ldconfig + ``` + + **ONNX-MLIR** (LLTFI branch, requires an MLIR-enabled LLVM build): + ```bash + git clone --recursive https://github.com/DependableSystemsLab/onnx-mlir-lltfi.git + mv onnx-mlir-lltfi onnx-mlir && cd onnx-mlir && git checkout LLTFI && cd .. + MLIR_DIR=$(pwd)/llvm-project/build/lib/cmake/mlir + mkdir onnx-mlir/build && cd onnx-mlir/build + cmake -G Ninja -DCMAKE_CXX_COMPILER=/usr/bin/c++ -DMLIR_DIR=${MLIR_DIR} .. + cmake --build . && ninja install + ``` + +8. GraphViz (for dependency graph visualisation) + + +Installation +------------ + +Run `./setup --help` for a full option list. + ``` - On Ubuntu systems where LLVM is installed via apt, `clang` may only be available as `clang-15` (not `clang`) and will not be found automatically. In that case, pass `-LLVM_GXX_BIN_DIR` explicitly: +./setup -LLFI_BUILD_ROOT \ + -LLVM_SRC_ROOT \ + -LLVM_DST_ROOT ``` -./setup -LLFI_BUILD_ROOT /path/to/LLFI-build -LLVM_SRC_ROOT /path/to/llvm-project -LLVM_DST_ROOT /usr/lib/llvm-15 -LLVM_GXX_BIN_DIR /usr/lib/llvm-15/bin + +The build root must not already exist. Delete it first when rebuilding from +scratch. To rebuild after source changes without re-running setup: + +```bash +cd /path/to/LLTFI-build && make ``` -Details about running the Web GUI for LLTFI can be found [here](web-app/README.MD) -### Building LLTFI using Docker: ### +Docker +------ -`docker/Dockerfile` can be used to build and run LLTFI in a docker container. You can modify the Dockerfile according to your system and project requirements. More details can be found [here](docker/README.md) +`docker/Dockerfile` builds and runs LLTFI in a container. Copy the Dockerfile +outside the repository, then: -Steps to build: -1. **Creating a docker image from the Dockerfile:** Copy the Dockerfile to a directory of your choice outside this repository. To create an image, run the command `docker build --tag imageName .` in the terminal. -2. **Starting a docker container:** Once the above step is completed, a docker container can be started using the command `docker run -it imageName` +```bash +docker build --tag lltfi . +docker run -it lltfi +``` +See [docker/README.md](docker/README.md) for details. -### Running tests: ### -Running all regression tests after installation is highly recommended. Note that you may encounter some error messages during the fault injection stage. This is normal. Once all tests have been completed and they all passed, LLTFI is correctly installed. -For complete test of whole of LLTFI, please use LLTFI test suite and refer to the wiki page: [Test suite for regression test](https://github.com/DependableSystemsLab/LLTFI/wiki/Test-Suite-for-Regression-Test) for details. Tests must be run from the build directory: -``` +Running Tests +------------- + +Tests must be run from the **build** directory. Running all regression tests +after installation is strongly recommended. +Individual test categories can be run separately: + +```bash cd /test_suite -python3 SCRIPTS/llfi_test --all + +python3 SCRIPTS/llfi_test --all_cpp # 21 core tests (expected: 21/21 PASS) +python3 SCRIPTS/llfi_test --all_hardware_faults # hardware fault injection (8 tests) +python3 SCRIPTS/llfi_test --all_trace_tools_tests # trace analysis tools (3 tests) +python3 SCRIPTS/llfi_test --all_makefile_generation # Makefile generation (2 tests) +``` + +Error messages during fault injection runs are normal and expected. + +#### ML / ONNX tests (optional dependencies) + +```bash +python3 SCRIPTS/llfi_test --all_ml ``` +Tests that require missing dependencies are reported as **SKIP** (not FAIL) and +excluded from the pass/fail count. + +| Group | Requirements | +|-------|-------------| +| ML tool unit tests | `pip install onnx pygraphviz pydot` | +| Instruction duplication (synthetic IR) | LLTFI build only | +| Instruction duplication (real model IR) | `model.ll` from `sample_programs/.../mnist/compile.sh` | +| TensorFlow → ONNX | `pip install tensorflow tf2onnx onnx` | +| PyTorch → ONNX | `pip install torch onnx` | +| ONNX → LLVM IR | onnx-mlir binary (set `ONNX_MLIR_BUILD`) | +| Fault injection (ML) | LLTFI build + `model.ll` | + Results ------- -After fault injection, output from LLFI and the tested application can be found -in the *llfi* directory. -| Directory | Contents | -| ----------------------| ---------------------------------------------- | -| *std_output* | Piped STDOUT from the tested application | -| *llfi_stat_output* | Fault injection statistics and trace files | -| *error_output* | Failure reports (program crashes, hangs, etc.) | -| *baseline* | Golden output and profiling trace | -| *prog_output* | Disk output from faulty runs | +After fault injection, output is in the `llfi/` directory inside your program +folder. For a full description of each file see +[architecture.md — Interface Between the Two Layers](architecture.md). + +| Directory | Contents | +|-----------|----------| +| `std_output/` | Piped stdout from each run | +| `llfi_stat_output/` | Fault injection statistics, profiling data, trace files | +| `error_output/` | Failure reports (crashes, hangs, SDCs) | +| `baseline/` | Golden output and profiling trace | +| `prog_output/` | Disk output from faulty runs | + +Reproducing ISSRE'23 Experiments +--------------------------------- -Reproducing the experiments in our ISSRE'23 paper -------------------------------------------------- +See the [ISSRE'23 AE branch README](https://github.com/DependableSystemsLab/LLTFI/blob/ISSRE23_AE/README.md). -Please refer to the following README file for instructions on obtaining benchmarks and reproducing the experiments in our ISSRE'23 paper. [ISSRE'23 AE](https://github.com/DependableSystemsLab/LLTFI/blob/ISSRE23_AE/README.md) References ---------- + * [LLFI Paper](http://blogs.ubc.ca/karthik/2013/02/15/llfi-an-intermediate-code-level-fault-injector-for-soft-computing-applications/) * [LLFI Wiki](https://github.com/DependableSystemsLab/LLFI/wiki) * [LLTFI Wiki](https://github.com/DependableSystemsLab/LLTFI/wiki) -* Udit Kumar Agarwal, Abraham Chan, Karthik Pattabiraman. LLTFI: Framework agnostic fault injection for machine learning applications (Tools and Artifacts Track). International Symposium on Software Reliability Engineering (ISSRE), 2022. 10 pages. [LLTFI Paper](https://www.dropbox.com/s/lgr3ed75sy0fq2p/issre22-camera-ready.pdf?dl=0) -* Udit Kumar Agarwal, Abraham Chan, Karthik Pattabiraman. Resilience Assessment of Large Language Models under Transient Hardware Faults (PER). International Symposium on Software Reliability Engineering (ISSRE), 2023. [Paper](https://www.dropbox.com/scl/fi/mv6yehk0lctcz3l4efy0k/ISSRE23_Udit.pdf?rlkey=dzwbxk7js29pqjwirjj25ik8q&dl=0) +* Udit Kumar Agarwal, Abraham Chan, Karthik Pattabiraman. *LLTFI: Framework agnostic fault injection for machine learning applications.* ISSRE 2022. [PDF](https://www.dropbox.com/s/lgr3ed75sy0fq2p/issre22-camera-ready.pdf?dl=0) +* Udit Kumar Agarwal, Abraham Chan, Karthik Pattabiraman. *Resilience Assessment of Large Language Models under Transient Hardware Faults.* ISSRE 2023. [PDF](https://www.dropbox.com/scl/fi/mv6yehk0lctcz3l4efy0k/ISSRE23_Udit.pdf?rlkey=dzwbxk7js29pqjwirjj25ik8q&dl=0) + Citations ----------- +--------- -
+```bibtex
 @article{Agarwal22LLTFI,
-  title={LLTFI: Framework agnostic fault injection for machine learning applications (Tools and Artifacts Track)},
-  author={Agarwal, Udit and Chan, Abraham and Pattabiraman, Karthik},
-  journal={International Symposium on Software Reliability Engineering (ISSRE)},
-  year={2022},
-  publisher={IEEE}
+  title   = {LLTFI: Framework agnostic fault injection for machine learning
+             applications (Tools and Artifacts Track)},
+  author  = {Agarwal, Udit and Chan, Abraham and Pattabiraman, Karthik},
+  journal = {International Symposium on Software Reliability Engineering (ISSRE)},
+  year    = {2022},
+  publisher = {IEEE}
 }
-
+``` + +--- + +Read *caveats.txt* for known limitations and gotchas. -====== -Read *caveats.txt* for caveats and known problems. +Read *CODING_GUIDELINES.md* for C++, C, and Python coding conventions. +Read *architecture.md* for a detailed description of the internal architecture. diff --git a/WATERS_auto_conf.sh b/WATERS_auto_conf.sh deleted file mode 100644 index c5a19b3e..00000000 --- a/WATERS_auto_conf.sh +++ /dev/null @@ -1,93 +0,0 @@ -#!/bin/bash - -####################### Readme ####################### -# This script is for initializing a WATERS server. -# Basic development tools and CMake 2.8 will be inst- -# alled. LLVM 2.9 and a custom version of LLFI(for LL- -# VM 2.9) will be downloaded, build and installed. -# The default user account for this script is 'root'. -# So that it does not involve typing any password dur- -# ing the process -# -# Author: Qining -###################################################### - -export MAINDIR=/home - -## Install basic pacakges -yum update -y -yum install -y nautilus-open-terminal -yum install -y xauth -yum install -y dbus-x11 -yum groupinstall -y 'Fonts' -yum install -y gedit -yum groupinstall -y 'Development Tools' -yum install -y wget -yum install -y git - -## Install CMake 2.8 -cd $MAINDIR -mkdir cmake -wget http://www.cmake.org/files/v2.8/cmake-2.8.12.2-Linux-i386.sh -sh ./cmake-2.8.12.2-Linux-i386.sh --prefix=$MAINDIR/cmake --exclude-subdir -echo "export PATH=/home/cmake/bin:\$PATH">>/root/.bashrc -source /root/.bashrc -cmake --version - -mkdir $MAINDIR/Downloads - -## Compile, build and install llvm-3.4 and clang -cd $MAINDIR/Downloads -wget http://llvm.org/releases/3.4/llvm-3.4.src.tar.gz -wget http://llvm.org/releases/3.4/clang-3.4.src.tar.gz -tar -xvzf llvm-3.4.src.tar.gz -tar -xvzf clang-3.4.src.tar.gz -mv llvm-3.4 $MAINDIR/llvmsrc -mv clang-3.4 /home/llvmsrc/tools -mkdir $MAINDIR/llvm -cd $MAINDIR/llvm -cmake ../llvmsrc -DLLVM_REQUIRES_RTTI=1 -make -j24 -#make install -#echo "export PATH=/home/llvm/bin:\$PATH">>/root/.bashrc - -## Install python3 -cd $MAINDIR/Downloads -wget http://www.python.org/ftp/python/3.3.2/Python-3.3.2.tar.bz2 -tar jxvf Python-3.3.2.tar.bz2 -cd Python-3.3.2 -./configure -./make -j24 -./make install - -# install pyyaml -cd $MAINDIR/Downloads -wget http://pyyaml.org/download/pyyaml/PyYAML-3.11.tar.gz -tar -xvzf PyYAML-3.11.tar.gz -mv PyYAML-3.11 $MAINDIR/pyyamlsrc -cd $MAINDIR/pyyamlsrc -python3 setup.py install --prefix=$MAINDIR/pyyaml -echo "export PYTHONPATH=\$PYTHONPATH:/home/pyyaml/lib/python3.3/site-packages">>/root/.bashrc -source /root/.bashrc - -## install git 1.8 -cd $MAINDIR/Downloads -wget http://git-core.googlecode.com/files/git-1.8.3.4.tar.gz -wget -O git-manpages-1.8.3.4.tar.gz http://code.google.com/p/git-core/downloads/detail?name=git-manpages-1.8.3.4.tar.gz&can-2&q= -yum install -y zlib-devel perl-CPAN gettext curl-devel -tar xvfz git-1.8.3.4.tar.gz -cd git-1.8.3.4 -./configure -make -j24 -make install -git --version - -## Install LLFI -cd $MAINDIR/Downloads -git clone -b merge https://github.com/DependableSystemsLab/LLFI.git llfisrc -mv llfisrc $MAINDIR/llfisrc -cd $MAINDIR/llfisrc -./setup -LLVM_DST_ROOT $MAINDIR/llvm -LLVM_SRC_ROOT $MAINDIR/llvmsrc -LLFI_BUILD_ROOT $MAINDIR/llfi -LLVM_GXX_BIN_DIR $MAINDIR/llvm/bin -#echo "export PATH=\$PATH:/home/llfi/bin">>/root/.bashrc -#source /root/.bashrc - diff --git a/architecture.md b/architecture.md new file mode 100644 index 00000000..1ae486bd --- /dev/null +++ b/architecture.md @@ -0,0 +1,526 @@ +# LLTFI Architecture + +LLTFI is an LLVM-based fault injection framework for C/C++ and ML applications. +It is split into two completely independent layers that run at different times: +the **compile-time layer** (LLVM passes that transform IR) and the +**runtime layer** (a C library linked into the instrumented binary). These layers +communicate through a small set of well-defined interfaces: LLVM metadata embedded +in the IR, a config file written by the driver scripts, and log files written by +the runtime. + +--- + +## High-Level Workflow + +``` + Source code / LLVM IR + │ + ▼ + ┌─────────────────────────────────────────────────────────┐ + │ bin/instrument.py (reads input.yaml) │ + │ invokes opt with --passes="genllfiindexpass, │ + │ profilingpass,faultinjectionpass[,insttracepass]" │ + └──────────────────────────┬──────────────────────────────┘ + │ + ┌─────────────────▼─────────────────┐ + │ llfi-passes.so │ + │ (LLVM pass plugin – see §2) │ + └─────────────────┬─────────────────┘ + │ produces instrumented IR + ▼ + ┌──────────────────────────┐ + │ Instrumented binary │ + │ (IR linked against │ + │ libllfi-rt.so – §3) │ + └──────┬───────────┬───────┘ + │ │ + ┌─────────▼──┐ ┌─────▼──────────────┐ + │ prof.exe │ │ fi.exe │ + │ (profiling)│ │ (fault injection) │ + └─────────┬──┘ └─────┬──────────────┘ + │ │ + llfi.stat.prof.txt llfi.stat.fi.injectedfaults.txt + (opcode counts, llfi.stat.trace.txt (optional) + ML layer timings) prog_output/ +``` + +The driver scripts (`profile.py`, `injectfault.py`) orchestrate execution and +collect results. `tracediff.py` and related tools in `tools/` perform +post-processing. + +--- + +## 1. Repository Layout + +``` +llvm_passes/ Compile-time: LLVM pass plugin source + core/ Pass infrastructure and selector framework + hardware_failures/ Built-in hardware fault instruction selectors + instruction_duplication/ SID pass (SEDPasses.so) + RegisterPasses.cpp New-PM plugin entry point for llfi-passes.so + CustomTensorOperatorInstSelector.cpp ML-specific inst selector + MainGraphInstSelector.cpp ML-specific inst selector + +runtime_lib/ Runtime: C/C++ library linked into instrumented binaries + FaultInjectionLib.c Core fault injection (preFunc, injectFunc) + ProfilingLib.cpp Opcode profiling (doProfiling, endProfiling) + InstTraceLib.c Instruction tracing (printInstTracer) + FaultInjectorManager.h/cpp Plugin registry for custom fault injectors + MLFaultInjectionLib.cpp ML layer tracking (lltfiMLLayer) + +bin/ Driver scripts + instrument.py Compiles IR with LLFI passes; produces prof.exe + fi.exe + profile.py Runs prof.exe; collects opcode frequencies + injectfault.py Runs fi.exe repeatedly with configured fault parameters + batchInstrument.py Runs instrument.py across multiple programs in one call + batchProfile.py Runs profile.py across multiple programs + batchInjectfault.py Runs injectfault.py across multiple programs + HardwareFailureAutoScan.py Lists applicable hardware selectors for a program + InjectorAutoScan.py Lists all registered fault injector names (fi_type values) + llfi-gui.py Launches the LLFI graphical front-end + +tools/ Post-processing and ML utilities + GenerateMakefile/ Test harness Makefile generator + tracediff.py Compares golden vs faulty instruction traces + traceontograph.py Overlays traces onto dependency graph + ExtendONNXModel.py Prepares ONNX model for LLTFI instrumentation + outputONNXGraph.py Visualises ONNX model graph + compiletoIR.py Converts source to LLVM IR + +docs/ Reference schemas and user guides + tutorial_first_experiment.md End-to-end C/C++ walkthrough and output interpretation + tutorial_ml_experiment.md End-to-end ML/ONNX walkthrough (layer targeting, multi-fault) + adding_a_test.md How to add a regression test case + input_yaml_guide.md Prose guide to writing input.yaml (start here) + input_masterlist.yaml Full key reference for input.yaml (C/C++ programs) + input_masterlist_ml.yaml Full key reference for ML programs + +test_suite/ Regression tests +``` + +--- + +## 2. Compile-Time Layer: LLVM Pass Plugin + +Everything in `llvm_passes/` runs at compile time, inside `opt`. It is built into +two shared libraries: + +| Library | Contents | +|---------|----------| +| `llfi-passes.so` | All core passes, all hardware fault selectors | +| `SEDPasses.so` | Selective Instruction Duplication (ML only) | + +### 2.1 Pass Pipeline + +The standard instrumentation pipeline runs three passes in sequence. +`instrument.py` invokes `opt` as: + +``` +opt -load-pass-plugin llfi-passes.so \ + --passes="genllfiindexpass,profilingpass,faultinjectionpass" \ + [selector options] input.ll -o output.ll +``` + +The passes must run in this order because each depends on the previous: + +``` +GenLLFIIndexPass + Assigns a unique integer (the LLFI index) to every instruction in the + module. The index is stored as LLVM metadata on the instruction and + read at runtime to identify the instruction being executed. + Output file: llfi.stat.totalindex.txt + │ + ▼ +ProfilingPass + Inserts a call to doProfiling(opcode) before each FI-candidate + instruction. Also inserts endProfiling() at program exit points. + For ML models, inserts lltfiMLLayer() calls around OMInstrumentPoint + boundaries to track per-layer timing. + │ + ▼ +FaultInjectionPass + Inserts preFunc() / injectFunc() calls around each FI-candidate + register. The selector framework (§2.2) determines which instructions + and registers are candidates. +``` + +Optional passes can be added: + +``` +InstTracePass + Inserts printInstTracer() calls to record register values at runtime. + Used for trace-based analysis with tracediff.py. + +LLFIDotGraphPass (registered as "dotgraphpass") + Generates llfi.stat.graph.dot, a Graphviz data-dependency graph of the + module. Added to the pass list when genDotGraph: true is set in + input.yaml. Used by the zgrviewer-based graph viewer. +``` + +Two additional passes are used only by the auto-scan scripts and not in the +normal instrumentation pipeline: + +``` +HardwareFailureAutoScanPass Writes llfi.applicable.hardware.selectors.txt +``` + +### 2.2 Instruction and Register Selector Framework + +The selector framework decides *which* instructions and *which* register +positions within those instructions to inject faults into. It is the central +extensibility point of LLTFI. + +#### Class Hierarchy + +``` +FIInstSelector (llvm_passes/core/FIInstSelector.h) +│ virtual bool isInstFITarget(Instruction*) = 0 +│ virtual void getCompileTimeInfo(map&) +│ void getFIInsts(Module&, set*) — calls isInstFITarget +│ +└── HardwareFIInstSelector + │ (for hardware fault modes — any instruction may be the target) + ├── InstTypeFIInstSelector by opcode (fadd, load, store, …) + ├── FuncNameFIInstSelector all insts in a named function + ├── LLFIIndexFIInstSelector one specific LLFI index + ├── CustomTensorOperatorInstSelector ONNX operator boundary (ML) + └── MainGraphInstSelector all arith insts in main_graph (ML) + +FIRegSelector (llvm_passes/core/FIRegSelector.h) +│ virtual bool isRegofInstFITarget(Value*, Instruction*) = 0 +│ void getFIInstRegMap(set, map>*) +│ +└── HardwareFIRegSelector + └── RegLocBasedFIRegSelector dstreg / srcreg1–4 / allreg / allsrcreg +``` + +#### Controller + +`Controller` (singleton, `llvm_passes/core/Controller.h`) is the glue between +the pass pipeline and the selector framework. It parses LLVM command-line +options set by `instrument.py`, instantiates the appropriate selector objects, +and exposes the final `fi_inst_regs_map` (a mapping of `Instruction*` → +`list` register positions) to `ProfilingPass` and `FaultInjectionPass`. + +Key options it parses: + +| Option | Values | Meaning | +|--------|--------|---------| +| `-fiinstselmethod` | `insttype`, `funcname`, `custominstselector` | Which inst selector to use | +| `-includeinst` / `-excludeinst` | opcode names | Filter by instruction type | +| `-includefunc` / `-excludefunc` | function names | Filter by function | +| `-fiinstselectorname` | selector name string | Used with `custominstselector` | +| `-firegselmethod` | `regloc`, `customregselector` | Which reg selector to use | +| `-fireglocation` | `dstreg`, `srcreg1`–`srcreg4`, `allreg`, `allsrcreg` | Which register position | + +#### Custom Selector Manager + +`FICustomInstSelectorManager` and `FICustomRegSelectorManager` are singleton +registries. Hardware selectors register themselves at static +initialisation time via: + +```cpp +static RegisterFIInstSelector X("mymode", new MyModeInstSelector()); +static RegisterFIRegSelector Y("mymode", new MyModeRegSelector()); +``` + +The manager resolves a name string (e.g. `"insttype"`, `"WrongPointer(Data)"`) +to a concrete selector object at instrumentation time. + +### 2.3 Hardware Fault Selectors + +Hardware fault selectors model low-level physical faults (bit-flips, stuck-at +bits) that can affect any instruction. They are built into `llfi-passes.so` +and are always available. + +| Selector name | Class | Targets | +|---------------|-------|---------| +| `insttype` | `InstTypeFIInstSelector` | All instructions matching a set of LLVM opcodes | +| `funcname` | `FuncNameFIInstSelector` | All instructions in one or more named functions | +| `llfiindex` | `LLFIIndexFIInstSelector` | The single instruction with a given LLFI index | +| `maingraph` | `MainGraphInstSelector` | FAdd / FMul / FCmp in `main_graph()` (ML) | +| `CustomTensorOperator` | `CustomTensorOperatorInstSelector` | FP arith inside a named ONNX operator region (ML) | + +Register selection for all hardware selectors is handled by +`RegLocBasedFIRegSelector`, controlled by the `-fireglocation` option. + +The `HardwareFailureAutoScanPass` (invoked by `HardwareFailureAutoScan.py`) +enumerates all registered hardware selectors and writes the list to +`llfi.applicable.hardware.selectors.txt`. + +### 2.4 ML Fault Selectors + +ML fault injection operates on LLVM IR compiled from ONNX models via onnx-mlir. +The onnx-mlir compiler annotates the IR with `@OMInstrumentPoint(operator_id, flag)` +calls that delimit each tensor operator's computation region. LLTFI's ML +selectors use these boundaries to confine fault injection to a specific layer. + +#### CustomTensorOperatorInstSelector + +Registered as `"CustomTensorOperator"`. Selects `FAdd`, `FSub`, `FMul`, +`FDiv`, and `FCmp` instructions inside `main_graph()` that fall between a +matching `OMInstrumentPoint(id, 2)` (start) and `OMInstrumentPoint(id, 1)` +(end) pair. The operator is identified by name (e.g. `conv`, `relu`, `matmul`) +mapped to ONNX operator IDs. + +Command-line options (set via `input.yaml`): +- `--layerName=conv;relu` — target these operator types +- `--layerNo=2;0` — target the 2nd `conv` and all `relu` occurrences (0 = all) + +#### MainGraphInstSelector + +Registered as `"maingraph"`. Simpler selector that targets all `FAdd`, `FMul`, +and `FCmp` instructions anywhere in `main_graph()`, without operator-boundary +awareness. Used when operator-level granularity is not required. + +### 2.5 Selective Instruction Duplication (SEDPasses.so) + +The `InstructionDuplicationPass` in `SEDPasses.so` is a separate pass plugin +for soft-error detection and correction in ML models. It is not part of the +normal fault injection pipeline — it is applied to the model IR *before* +instrumentation. + +For each selected arithmetic instruction, the pass: +1. Duplicates the instruction (clone + insert immediately after) +2. Inserts a call to `compareFloatValues(original, duplicate)` which returns + the bitwise AND of both results +3. Replaces uses of the original result with the `compareFloatValues` return + +When both copies agree (no fault), `compareFloatValues(x, x) == x` so the +result is unchanged. When they disagree (transient fault in one copy), the AND +masks the corrupted bits. + +The pass supports two modes: +- **AID** (Arithmetic Instruction Duplication): each selected instruction + duplicated independently +- **ACD** (Arithmetic Chain Duplication, `--enableChainDuplication`): + consecutive arithmetic sequences duplicated as a unit and compared only at + the end of the chain + +`compareFloatValues` is defined in +`llvm_passes/instruction_duplication/shared_lib/SIDHelperFunctions.cpp` and +linked into the model IR via `llvm-link` before execution. + +--- + +## 3. Runtime Layer: libllfi-rt + +The runtime library is a set of C/C++ files compiled to a shared library +(`libllfi-rt.so`) and linked into every instrumented binary. It is +**completely independent of LLVM** — it runs inside the target program, not +inside `opt`. + +The runtime reads its configuration from `llfi.config.runtime.txt` (written +by `injectfault.py` before each run) and writes its results to log files in +the `llfi/` output directory. + +### 3.1 Fault Injection Runtime (FaultInjectionLib.c) + +This is the core of the runtime layer. The instrumented IR calls two functions +around each fault-injection-candidate register: + +```c +// Called before the instruction executes. +// Returns true if this dynamic instance should be injected, false otherwise. +bool preFunc(long llfi_index, unsigned opcode, + unsigned my_reg_index, unsigned total_reg_target_num); + +// Called after preFunc returns true. Corrupts the register value in-place. +void injectFunc(long llfi_index, unsigned size, char *buf, + unsigned my_reg_index, unsigned reg_pos, char *opcode_str); +``` + +**`preFunc` selection logic:** +The runtime uses either *cycle-based* or *index-based* targeting (set in +`llfi.config.runtime.txt`). In cycle-based mode it counts every call across +all instructions and injects when the cycle counter matches `fi_cycle`. In +index-based mode it injects the first dynamic occurrence of the instruction +with the given LLFI index. A fault is injected at most once per instruction +execution (even when multiple registers are targeted). + +**`injectFunc` fault types** (set by `fi_type` in config): + +| `fi_type` | Operation | +|-----------|-----------| +| `bitflip` | XOR a randomly selected bit | +| `stuck_at_0` | AND the bit to force it to 0 | +| `stuck_at_1` | OR the bit to force it to 1 | + +After injection, the runtime appends a record to +`llfi.stat.fi.injectedfaults.txt` with the LLFI index, register size, bit +position, and fault type. + +**Config file keys** (`llfi.config.runtime.txt`): + +| Key | Meaning | +|-----|---------| +| `fi_type` | Fault type (see table above) | +| `fi_cycle` | Inject at this dynamic instruction cycle | +| `fi_index` | Inject at this LLFI index (alternative to `fi_cycle`) | +| `fi_reg_index` | Target register index within the instruction (random if absent) | +| `fi_bit` | Target bit position (random if absent) | +| `fi_num_bits` | Number of bits to corrupt (default 1) | +| `fi_max_multiple` | Number of injection points for multi-fault experiments | +| `fi_next_cycle` | Additional cycles for multi-fault injection | + +### 3.2 Fault Injector Plugin Registry (FaultInjectorManager) + +Software fault modes that need custom injection logic register a `FaultInjector` +subclass with the singleton `FaultInjectorManager`. The manager resolves the +injector name from `fi_type` in the config file to the corresponding +`injectFault()` implementation. + +The runtime injector registrations live in two tracked files: + +- `runtime_lib/CommonFaultInjectors.cpp` — the three hardware injectors + (`bitflip`, `stuck_at_0`, `stuck_at_1`). + +### 3.3 Profiling Runtime (ProfilingLib.cpp) + +```c +void doProfiling(int opcode); // Inserted before each FI-candidate inst +void endProfiling(); // Inserted at program exit +``` + +`doProfiling` increments a per-opcode counter weighted by an estimated cycle +cost. `endProfiling` writes `llfi.stat.prof.txt`: + +``` +total_cycle= +``` + +For ML models, `MLFaultInjectionLib.cpp` provides: + +```c +void lltfiMLLayer(int64_t layerName, int64_t start); +``` + +Called at each `OMInstrumentPoint` boundary, it records the start and end +cycle of each tensor operator. The profiling data drives `injectfault.py`'s +fault space sampling (which dynamic instruction cycles to target). + +### 3.4 Instruction Trace Runtime (InstTraceLib.c) + +```c +void printInstTracer(long instID, char *opcode, int size, char *ptr, int maxPrints); +``` + +Writes a line to `llfi.stat.trace.txt` for every call. In the *golden run* +(no fault) all instructions are traced. In the *fault injection run* tracing +is gated by a state machine: it starts after the fault is injected and records +the next `maxPrints` instructions. The delta between the two trace files is the +input to `tracediff.py`. + +--- + +## 4. Interface Between the Two Layers + +The compile-time and runtime layers share three interfaces. There is no shared +header between them — the contract is purely by convention. + +### 4.1 LLVM Metadata (compile-time → runtime) + +`GenLLFIIndexPass` stores each instruction's LLFI index as LLVM metadata: + +```cpp +// Written by GenLLFIIndexPass: +inst->setMetadata("llfi_index", MDNode::get(ctx, ConstantAsMetadata::get( + ConstantInt::get(Type::getInt64Ty(ctx), index)))); +``` + +`FaultInjectionPass` reads this metadata when generating the `preFunc` / +`injectFunc` calls, embedding the index as a constant argument. The runtime +never parses metadata — it receives the index as a plain integer argument. + +### 4.2 Runtime Config File (driver → runtime) + +`injectfault.py` writes `llfi.config.runtime.txt` immediately before launching +`fi.exe`. The runtime reads it at startup in `initInjections()`. No LLVM types +or headers are involved — it is a plain text key=value file. + +### 4.3 Log Files (runtime → driver / post-processing tools) + +All output files are written by the runtime to the `llfi/` directory created +by `instrument.py`: + +``` +llfi/ + llfi_stat_output/ + llfi.stat.totalindex.txt Total instructions (GenLLFIIndexPass) + llfi.stat.prof.txt Profiling results (ProfilingLib) + llfi.stat.fi.injectedfaults.txt Injection log (FaultInjectionLib) + llfi.stat.trace.*.txt Instruction traces (InstTraceLib) + std_output/ Stdout from each run + error_output/ Stderr / crash info from each run + prog_output/ Disk output from faulty runs + baseline/ Golden-run output and trace +``` + +--- + +## 5. Adding a New Fault Mode + +### New hardware fault mode (inst selector only) + +1. Subclass `HardwareFIInstSelector` in a new `.cpp` under `llvm_passes/`: + ```cpp + class MySelector : public HardwareFIInstSelector { + bool isInstFITarget(Instruction* inst) override { ... } + void getCompileTimeInfo(map& info) override { ... } + }; + static RegisterFIInstSelector X("mymode", new MySelector()); + ``` +2. Add the file to `llvm_passes/CMakeLists.txt` under the `llfi-passes` target. +3. Rebuild (`make` in the build root). + +### New fault injector (runtime) + +Subclass `FaultInjector` in `runtime_lib/` and register it: +```cpp +class MySoftwareInjector : public FaultInjector { + void injectFault(long index, unsigned size, unsigned fi_bit, + char *buf) override { ... } +}; +static RegisterFaultInjector R("MySoftwareInjector", new MySoftwareInjector()); +``` + +Set `fi_type=MySoftwareInjector` in `input.yaml` to select it at runtime. + +--- + +## 6. Key Design Decisions + +**Selector registration at static init time.** Both inst and reg selectors +register themselves via `static RegisterFI*Selector` objects, which run before +`main()`. This means new selectors are available as soon as they are linked +into `llfi-passes.so` — no central registry to update, no switch statement to +extend. + +**Two-phase runtime check.** `preFunc` and `injectFunc` are separate calls. +`preFunc` is cheap (counter comparison) and runs every time. `injectFunc` is +called only when `preFunc` returns true, keeping the hot path overhead minimal. + +**LLFI index as the universal identifier.** Every instruction in the module gets +a unique stable integer at compile time. This index is the only way the +compile-time and runtime layers refer to the same instruction — no function +names, no IR text, no debug info dependency. + +**ML instrumentation is non-invasive to the core.** The ML-specific selectors +(`CustomTensorOperatorInstSelector`, `MainGraphInstSelector`) are ordinary +`HardwareFIInstSelector` subclasses that happen to look for `OMInstrumentPoint` +calls. The core `FaultInjectionPass` and runtime are unchanged for ML workloads. + +**New pass manager only (legacy PM removed).** The legacy `opt -load` / +`-enable-new-pm=0` interface was dropped in LLVM 17 and is no longer supported +in LLTFI. All passes — including `InstructionDuplication` — use the new pass +manager (`PassInfoMixin`, `llvmGetPassPluginInfo`). This removes the need to +maintain two registration paths for every pass and aligns with LLVM's own +direction for pass infrastructure. + +**SEDPasses.so is separate from llfi-passes.so.** The +`InstructionDuplicationPass` lives in its own plugin so it can be applied to +the model IR *before* LLFI instrumentation. The SED transformation alters +instruction counts and structure; if it ran inside the LLFI pass pipeline (after +`GenLLFIIndexPass` has already assigned indices) those indices would be +invalidated. A separate library makes the mandatory pre-instrumentation ordering +explicit and prevents accidental composition in the wrong sequence. diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index bb758f6c..bd9f3b3e 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,7 +7,6 @@ project(bin) copy(instrument.py instrument) copy(injectfault.py injectfault) copy(profile.py profile) -copy(SoftwareFailureAutoScan.py SoftwareFailureAutoScan) copy(batchInstrument.py batchInstrument) copy(batchProfile.py batchProfile) copy(batchInjectfault.py batchInjectfault) diff --git a/bin/HardwareFailureAutoScan.py b/bin/HardwareFailureAutoScan.py index e26db449..9feb1ed6 100755 --- a/bin/HardwareFailureAutoScan.py +++ b/bin/HardwareFailureAutoScan.py @@ -20,14 +20,11 @@ import os import subprocess import sys -from subprocess import call -import yaml script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) +sys.path.append(os.path.join(script_path, "../config")) import llvm_paths - optbin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/opt") llcbin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/llc") llfipasses = os.path.join(script_path, "../llvm_passes/llfi-passes.so") @@ -40,56 +37,65 @@ # directory of the target IR basedir = "" + def parseArgs(args): - global basedir - global options - global filename - - cwd = os.getcwd() - for i, arg in enumerate(args): - option = arg - if os.path.isfile(arg): - basedir = os.path.realpath(os.path.dirname(arg)) - option = os.path.basename(arg) - options.append(option) - elif arg.startswith('-outputfilename='): - filename = arg.split('-outputfilename=')[-1] - options.append('-hardwarescan_outputfilename='+filename) - os.chdir(basedir) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + global basedir + global filename + + for arg in args: + option = arg + if os.path.isfile(arg): + basedir = os.path.realpath(os.path.dirname(arg)) + option = os.path.basename(arg) + options.append(option) + elif arg.startswith("-outputfilename="): + filename = arg.split("-outputfilename=")[-1] + options.append("-hardwarescan_outputfilename=" + filename) + os.chdir(basedir) + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def runAutoScan(args): - global filename - execlist = [optbin , "-load", llfipasses, "-HardwareFailureAutoScanPass", "-analyze"] + execlist = [ + optbin, + "-load-pass-plugin", + llfipasses, + "--passes=HardwareFailureAutoScanPass", + "-disable-output", + ] execlist.extend(args) - print(' '.join(execlist)) + print(" ".join(execlist)) p = subprocess.Popen(execlist) p.wait() if p.returncode != 0: print("ERROR: Hardware Auto scan pass return code !=0\n") - exit(p.returncode) - elif os.path.isfile(os.path.join(basedir, filename)) == False: - print("ERROR: No output file found at: "+os.path.join(basedir, filename)+"!\n") - exit(1) + sys.exit(p.returncode) + elif not os.path.isfile(os.path.join(basedir, filename)): + print( + "ERROR: No output file found at: " + os.path.join(basedir, filename) + "!\n" + ) + sys.exit(1) return 0 def main(args): parseArgs(args) - r = runAutoScan(options) + runAutoScan(options) return 0 + if __name__ == "__main__": - if len(sys.argv[1:]) < 1 or sys.argv[1] == '--help' or sys.argv[1] == '-h': - usage() - sys.exit(0) - r = main(sys.argv[1:]) - sys.exit(r) + if len(sys.argv[1:]) < 1 or sys.argv[1] == "--help" or sys.argv[1] == "-h": + usage() + sys.exit(0) + r = main(sys.argv[1:]) + sys.exit(r) diff --git a/bin/InjectorAutoScan.py b/bin/InjectorAutoScan.py index 01f4ee28..872a9ac1 100755 --- a/bin/InjectorAutoScan.py +++ b/bin/InjectorAutoScan.py @@ -19,12 +19,9 @@ import os import subprocess import sys -from subprocess import call -import yaml script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) -import llvm_paths +sys.path.append(os.path.join(script_path, "../config")) injector_scanner_bin = os.path.join(script_path, "../runtime_lib/InjectorScanner") prog = os.path.basename(sys.argv[0]) @@ -35,57 +32,60 @@ # directory of the target IR basedir = "" + def parseArgs(args): - global basedir - global options - global filename - - cwd = os.getcwd() - for i, arg in enumerate(args): - option = arg - if os.path.isfile(arg): - basedir = os.path.realpath(os.path.dirname(arg)) - option = os.path.basename(arg) - options.append(option) - elif arg.startswith('-outputfilename='): - filename = arg.split('-outputfilename=')[-1] - - options.extend(['-o', filename]) - os.chdir(basedir) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + global basedir + global filename + + for arg in args: + option = arg + if os.path.isfile(arg): + basedir = os.path.realpath(os.path.dirname(arg)) + option = os.path.basename(arg) + options.append(option) + elif arg.startswith("-outputfilename="): + filename = arg.split("-outputfilename=")[-1] + + options.extend(["-o", filename]) + os.chdir(basedir) + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def runAutoScan(args): - global filename execlist = [injector_scanner_bin] execlist.extend(args) - print(' '.join(execlist)) + print(" ".join(execlist)) p = subprocess.Popen(execlist) p.wait() if p.returncode != 0: print("ERROR: FaultInjector Auto scan pass return code !=0\n") - exit(p.returncode) - elif os.path.isfile(os.path.join(basedir, filename)) == False: - print("ERROR: No output file found at: "+os.path.join(basedir, filename)+"!\n") - exit(1) + sys.exit(p.returncode) + elif not os.path.isfile(os.path.join(basedir, filename)): + print( + "ERROR: No output file found at: " + os.path.join(basedir, filename) + "!\n" + ) + sys.exit(1) return 0 def main(args): parseArgs(args) - r = runAutoScan(options) + runAutoScan(options) return 0 + if __name__ == "__main__": - if len(sys.argv[1:]) < 1 or sys.argv[1] == '--help' or sys.argv[1] == '-h': - usage() - sys.exit(0) - r = main(sys.argv[1:]) - sys.exit(r) + if len(sys.argv[1:]) < 1 or sys.argv[1] == "--help" or sys.argv[1] == "-h": + usage() + sys.exit(0) + r = main(sys.argv[1:]) + sys.exit(r) diff --git a/bin/SoftwareFailureAutoScan.py b/bin/SoftwareFailureAutoScan.py deleted file mode 100755 index 407a7e9a..00000000 --- a/bin/SoftwareFailureAutoScan.py +++ /dev/null @@ -1,153 +0,0 @@ -#! /usr/bin/env python3 - -""" -%(prog)s takes a single IR file as input and scan all instructions to find potential applicable target points for fault injection, and to create applicable failure modes list. - -Usage: %(prog)s [OPTIONS] - -List of options: - --outputfilename=: set the name of the file that stores the list of applicable software failures (default: llfi.applicable.software.failures.txt) - Note: If is a relative path instead of an absolute path, the base path of will be the path of the targeting IR file instead of the calling path. - --numOfRuns : set the number of runs for each found failure mode (default: 1) ---enable_tracing: enable tracing ---enable_forward_injection: enable injection on the forward slice of the selected injection point ---enable_backward_injection: enable injection on the backward slice of the selected injection point ---no_input_yaml: do not generate an master input.yaml automatically. ---help: print this message. - -""" - - -import os -import subprocess -import sys -from subprocess import call -import yaml - -script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) -import llvm_paths - - -optbin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/opt") -llcbin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/llc") -llfipasses = os.path.join(script_path, "../llvm_passes/llfi-passes.so") -llfilinklib = os.path.join(script_path, "../runtime_lib") -prog = os.path.basename(sys.argv[0]) -# option list for AutoScan pass -options = [] -# output file name of AutoScan pass -filename = "llfi.applicable.software.failures.txt" -# directory of the target IR -basedir = "" -# input.yaml generation -run_num_dict = {'numOfRuns': 1} -tracing_dict = {'tracingPropagation':False, 'tracingPropagationOption':{'generateCDFG':False}} -trace_injection_dict = {'includeInjectionTrace':[]} - -no_input_yaml_flag = False - -def parseArgs(args): - global basedir - global options - global filename - global no_input_yaml_flag - - cwd = os.getcwd() - for i, arg in enumerate(args): - option = arg - if os.path.isfile(arg): - basedir = os.path.realpath(os.path.dirname(arg)) - option = os.path.basename(arg) - options.append(option) - elif arg.startswith('-outputfilename='): - filename = arg.split('-outputfilename=')[-1] - options.append('-softwarescan_outputfilename='+filename) - elif arg == "-numOfRuns": - run_num_dict['numOfRuns'] = int(args[i+1]) - elif arg == "--enable_tracing": - tracing_dict['tracingPropagation'] = True - tracing_dict['tracingPropagationOption']['generateCDFG'] = True - elif arg == "--enable_backward_injection": - trace_injection_dict['includeInjectionTrace'].append('forward') - elif arg == "--enable_forward_injection": - trace_injection_dict['includeInjectionTrace'].append('backward') - elif arg == "--no_input_yaml": - no_input_yaml_flag = True - os.chdir(basedir) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) - -def runAutoScan(args): - global filename - execlist = [optbin , "-load", llfipasses, "-genllfiindexpass", "-SoftwareFailureAutoScanPass", "-enable-new-pm=0"] - execlist.extend(args) - print(' '.join(execlist)) - p = subprocess.Popen(execlist) - p.wait() - if p.returncode != 0: - print("ERROR: Software Auto scan pass return code !=0\n") - exit(p.returncode) - elif os.path.isfile(os.path.join(basedir, filename)) == False: - print("ERROR: No output file found at: "+os.path.join(basedir, filename)+"!\n") - exit(1) - return 0 - -def generateInputYaml(): - global filename - selector_list = [] - with open(os.path.join(basedir, filename)) as f: - for line in f.readlines()[1:]: - selector_list.append(line.split('-')[-1].strip()) - customInstselector_dict = {'customInstselector':{'include':selector_list}} - yaml_dict = { - 'compileOption':{ - 'instSelMethod':[customInstselector_dict], - 'regSelMethod':'customregselector', - 'customRegSelector':'Automatic', - }, - 'runOption':[{ - 'run':{ - 'fi_type':'AutoInjection' - } - }] - } - yaml_dict['compileOption'].update(tracing_dict) - yaml_dict['compileOption'].update(trace_injection_dict) - yaml_dict['runOption'][0]['run'].update(run_num_dict) - yaml_text = yaml.dump(yaml_dict, default_flow_style=False) - with open(os.path.join(basedir, 'input.yaml'), 'w') as f: - f.write(yaml_text) - return 0 - -def cleanDir(): - global basedir - stale_config_file_path = os.path.join(basedir, 'llfi.config.compiletime.txt') - if os.path.isfile(stale_config_file_path): - os.remove(stale_config_file_path) - -def main(args): - global no_input_yaml_flag - - parseArgs(args) - r = runAutoScan(options) - if no_input_yaml_flag == False: - s = generateInputYaml() - cleanDir() - return 0 - -if __name__ == "__main__": - if len(sys.argv[1:]) < 1 or sys.argv[1] == '--help' or sys.argv[1] == '-h': - usage() - sys.exit(0) - r = main(sys.argv[1:]) - sys.exit(r) diff --git a/bin/batchInjectfault.py b/bin/batchInjectfault.py index 3635dcf8..87e890c0 100755 --- a/bin/batchInjectfault.py +++ b/bin/batchInjectfault.py @@ -10,101 +10,110 @@ You need to run \'batchInstrument\' first, then run %(prog)s under the same directory, the directory that contains multiple sub directories for different software faults. Same as \'batchInstrument\', %(prog)s is only applicable when multiple software failure modes are defined in input.yaml. """ -import sys, os, shutil +import sys, os import yaml import subprocess prog = os.path.basename(sys.argv[0]) script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) -import llvm_paths -injectfault_script = os.path.join(script_path, 'injectfault') +injectfault_script = os.path.join(script_path, "injectfault") # basedir and options are assigned in parseArgs(args) basedir = "" options = [] + def parseArgs(args): - global basedir - global options - cwd = os.getcwd() - for arg in args: - option = arg - if os.path.isfile(arg): - basedir = os.path.realpath(os.path.dirname(arg)) - option = os.path.basename(arg) - options.append(option) - os.chdir(basedir) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = -1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + global basedir + for arg in args: + option = arg + if os.path.isfile(arg): + basedir = os.path.realpath(os.path.dirname(arg)) + option = os.path.basename(arg) + options.append(option) + os.chdir(basedir) + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = -1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def phraseMasterYaml(): - master_yaml_dict = {} - model_list = [] - try: - with open('input.yaml', 'r') as master_yaml_file: - master_yaml_dict = yaml.safe_load(master_yaml_file) - except: - print ("ERROR: Unable to find input.yaml or load the input.yaml under current directory") - print (basedir) - sys.exit(-1) - try: - model_list = list(master_yaml_dict['compileOption']['instSelMethod'][0]['customInstselector']['include']) - except: - print ("ERROR: this wrapper script is not applicable on the input.yaml under current directory. Please note this script is only applicable on input.yaml files with multiple software failure models defined.") - print (basedir) - sys.exit(-1) - return master_yaml_dict, model_list + master_yaml_dict = {} + model_list = [] + try: + with open("input.yaml", "r") as master_yaml_file: + master_yaml_dict = yaml.safe_load(master_yaml_file) + except Exception: + print( + "ERROR: Unable to find input.yaml or load the input.yaml under current directory" + ) + print(basedir) + sys.exit(-1) + try: + model_list = list( + master_yaml_dict["compileOption"]["instSelMethod"][0]["customInstselector"][ + "include" + ] + ) + except Exception: + print( + "ERROR: this wrapper script is not applicable on the input.yaml under current directory. Please note this script is only applicable on input.yaml files with multiple software failure models defined." + ) + print(basedir) + sys.exit(-1) + return master_yaml_dict, model_list + def callInjectfault(model_list, *argv): - num_failed = 0 - for model in model_list: - workdir = os.path.join(basedir, "llfi-"+model) - try: - os.chdir(workdir) - except: - print ("ERROR: Unable to change to directory:", workdir) - sys.exit(-1) - faultinjection_exe_name = argv[0] - if faultinjection_exe_name.endswith('.ll'): - faultinjection_exe_name = faultinjection_exe_name.split('.ll')[0] - elif faultinjection_exe_name.endswith('.bc'): - faultinjection_exe_name = faultinjection_exe_name.split('.bc')[0] - faultinjection_exe_name = faultinjection_exe_name + '-faultinjection.exe' - command = [injectfault_script] - command.extend(['./llfi/'+faultinjection_exe_name]) - command.extend(argv[1:]) - print ("\nRun injectfault command:", ' '.join(command)) - try: - o = subprocess.check_output(command, stderr=sys.stderr) - except subprocess.CalledProcessError: - print ("injectfault:", model, " failed!") - num_failed += 1 - else: - print (o.decode()) - print ("injectfault:", model, " succeed!") - os.chdir(basedir) - return num_failed + num_failed = 0 + for model in model_list: + workdir = os.path.join(basedir, "llfi-" + model) + try: + os.chdir(workdir) + except Exception: + print("ERROR: Unable to change to directory:", workdir) + sys.exit(-1) + faultinjection_exe_name = argv[0] + if faultinjection_exe_name.endswith(".ll"): + faultinjection_exe_name = faultinjection_exe_name.split(".ll")[0] + elif faultinjection_exe_name.endswith(".bc"): + faultinjection_exe_name = faultinjection_exe_name.split(".bc")[0] + faultinjection_exe_name = faultinjection_exe_name + "-faultinjection.exe" + command = [injectfault_script] + command.extend(["./llfi/" + faultinjection_exe_name]) + command.extend(argv[1:]) + print("\nRun injectfault command:", " ".join(command)) + try: + o = subprocess.check_output(command, stderr=sys.stderr) + except subprocess.CalledProcessError: + print("injectfault:", model, " failed!") + num_failed += 1 + else: + print(o.decode()) + print("injectfault:", model, " succeed!") + os.chdir(basedir) + return num_failed + def main(*argv): - global options - parseArgs(argv) - master_yaml_dict, model_list = phraseMasterYaml() - r = callInjectfault(model_list, *options) - return r + parseArgs(argv) + master_yaml_dict, model_list = phraseMasterYaml() + r = callInjectfault(model_list, *options) + return r + if __name__ == "__main__": - if len(sys.argv[1:]) < 1 or sys.argv[1] == '--help' or sys.argv[1] == '-h': - usage() - sys.exit(0) - else: - argv = sys.argv[1:] - r = main(*argv) - sys.exit(r) \ No newline at end of file + if len(sys.argv[1:]) < 1 or sys.argv[1] == "--help" or sys.argv[1] == "-h": + usage() + sys.exit(0) + else: + argv = sys.argv[1:] + r = main(*argv) + sys.exit(r) diff --git a/bin/batchInstrument.py b/bin/batchInstrument.py index 199055cf..3928bbfb 100755 --- a/bin/batchInstrument.py +++ b/bin/batchInstrument.py @@ -16,7 +16,7 @@ --help(-h): Show help information Prerequisite: -You need to have 'input.yaml' under the same directory as , which contains appropriate options for LLFI. Usually, this command is only applicable for input.yaml file with a list of software failure modes included, i.e. using customInstSelector in instSelMethod and including software fault instruction selector (e.g. BufferOverflow(API)). +You need to have 'input.yaml' under the same directory as , which contains appropriate options for LLFI. Usually, this command is only applicable for input.yaml file with a list of software failure modes included, i.e. using customInstSelector in instSelMethod. """ import sys, os, shutil @@ -25,138 +25,149 @@ prog = os.path.basename(sys.argv[0]) script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) -import llvm_paths -instrument_script = os.path.join(script_path, 'instrument') +instrument_script = os.path.join(script_path, "instrument") # basedir and options are assigned in parseArgs(args) basedir = "" options = [] + def parseArgs(args): - global basedir - global options - cwd = os.getcwd() - for arg in args: - option = arg - if os.path.isfile(arg): - basedir = os.path.realpath(os.path.dirname(arg)) - option = os.path.basename(arg) - options.append(option) - os.chdir(basedir) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = -1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + global basedir + for arg in args: + option = arg + if os.path.isfile(arg): + basedir = os.path.realpath(os.path.dirname(arg)) + option = os.path.basename(arg) + options.append(option) + os.chdir(basedir) + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = -1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def parseMasterYaml(): - global basedir - master_yaml_dict = {} - model_list = [] - try: - with open('input.yaml', 'r') as master_yaml_file: - master_yaml_dict = yaml.safe_load(master_yaml_file) - except: - print ("ERROR: Unable to find input.yaml or load the input.yaml under basedir directory") - print (basedir) - sys.exit(-1) - try: - model_list = list(master_yaml_dict['compileOption']['instSelMethod'][0]['customInstselector']['include']) - except: - print ("ERROR: this wrapper script is not applicable on the input.yaml under current directory. Please note this script is only applicable on input.yaml files with multiple software failure models defined.") - print (basedir) - sys.exit(-1) - return master_yaml_dict, model_list + master_yaml_dict = {} + model_list = [] + try: + with open("input.yaml", "r") as master_yaml_file: + master_yaml_dict = yaml.safe_load(master_yaml_file) + except Exception: + print( + "ERROR: Unable to find input.yaml or load the input.yaml under basedir directory" + ) + print(basedir) + sys.exit(-1) + try: + model_list = list( + master_yaml_dict["compileOption"]["instSelMethod"][0]["customInstselector"][ + "include" + ] + ) + except Exception: + print( + "ERROR: this wrapper script is not applicable on the input.yaml under current directory. Please note this script is only applicable on input.yaml files with multiple software failure models defined." + ) + print(basedir) + sys.exit(-1) + return master_yaml_dict, model_list + def splitMasterYaml(master_yaml_dict, model_list): - global basedir - for model in model_list: - include_list = [model] - slave_yaml_dict = dict(master_yaml_dict) - slave_yaml_dict['compileOption']['instSelMethod'][0]['customInstselector']['include'] = include_list - slave_yaml_text = yaml.dump(slave_yaml_dict, default_flow_style=False) - workdir = os.path.join(basedir, "llfi-"+model) - try: - with open(os.path.join(workdir, 'input.yaml'), 'w') as f: - f.write(slave_yaml_text) - except: - print ("ERROR: Unable to write slave input.yaml file for model: ", model) - print ("workdir: ", workdir) - sys.exit(-1) - return 0 + for model in model_list: + include_list = [model] + slave_yaml_dict = dict(master_yaml_dict) + slave_yaml_dict["compileOption"]["instSelMethod"][0]["customInstselector"][ + "include" + ] = include_list + slave_yaml_text = yaml.dump(slave_yaml_dict, default_flow_style=False) + workdir = os.path.join(basedir, "llfi-" + model) + try: + with open(os.path.join(workdir, "input.yaml"), "w") as f: + f.write(slave_yaml_text) + except Exception: + print("ERROR: Unable to write slave input.yaml file for model: ", model) + print("workdir: ", workdir) + sys.exit(-1) + return 0 + def maybeRequired(abs_path): - basename = os.path.basename(abs_path) - if basename.startswith('llfi'): - return False - elif basename == 'input.yaml': - return False - return True + basename = os.path.basename(abs_path) + if basename.startswith("llfi"): + return False + elif basename == "input.yaml": + return False + return True + def prepareDirs(model_list): - global basedir - stuffs_under_basedir = [f for f in os.listdir(basedir) if maybeRequired(os.path.join(basedir, f))] - for model in model_list: - workdir = os.path.join(basedir, "llfi-"+model) - if os.path.exists(workdir): - try: - if os.path.isdir(workdir): - shutil.rmtree(workdir) - else: - os.remove(workdir) - except: - print ("ERROR: Unable to remove:", workdir, "for model:", model) - sys.exit(-1) - os.makedirs(workdir) - for s in stuffs_under_basedir: - s_path = os.path.join(basedir, s) - try: - if os.path.isfile(s_path): - shutil.copy(s_path, workdir) - else: - shutil.copytree(s_path, workdir) - except: - print ("ERROR: Unable to copy:", s_path, "\nto:", workdir) - sys.exit(-1) - return 0 + stuffs_under_basedir = [ + f for f in os.listdir(basedir) if maybeRequired(os.path.join(basedir, f)) + ] + for model in model_list: + workdir = os.path.join(basedir, "llfi-" + model) + if os.path.exists(workdir): + try: + if os.path.isdir(workdir): + shutil.rmtree(workdir) + else: + os.remove(workdir) + except Exception: + print("ERROR: Unable to remove:", workdir, "for model:", model) + sys.exit(-1) + os.makedirs(workdir) + for s in stuffs_under_basedir: + s_path = os.path.join(basedir, s) + try: + if os.path.isfile(s_path): + shutil.copy(s_path, workdir) + else: + shutil.copytree(s_path, workdir) + except Exception: + print("ERROR: Unable to copy:", s_path, "\nto:", workdir) + sys.exit(-1) + return 0 def callInstrument(model_list): - global basedir - global options - num_failed = 0 - for model in model_list: - workdir = os.path.join(basedir, "llfi-"+model) - os.chdir(workdir) - command = [instrument_script] - command.extend(options) - try: - o = subprocess.check_output(command, stderr=sys.stderr) - except subprocess.CalledProcessError: - print ("instrumenting:", model, " failed!") - num_failed += 1 - else: - print (o.decode()) - print ("instrumenting:", model, " succeed!") - os.chdir(basedir) - return num_failed + num_failed = 0 + for model in model_list: + workdir = os.path.join(basedir, "llfi-" + model) + os.chdir(workdir) + command = [instrument_script] + command.extend(options) + try: + o = subprocess.check_output(command, stderr=sys.stderr) + except subprocess.CalledProcessError: + print("instrumenting:", model, " failed!") + num_failed += 1 + else: + print(o.decode()) + print("instrumenting:", model, " succeed!") + os.chdir(basedir) + return num_failed + def main(): - parseArgs(sys.argv[1:]) - master_yaml_dict, model_list = parseMasterYaml() - prepareDirs(model_list) - splitMasterYaml(master_yaml_dict, model_list) - r = callInstrument(model_list) - return r + parseArgs(sys.argv[1:]) + master_yaml_dict, model_list = parseMasterYaml() + prepareDirs(model_list) + splitMasterYaml(master_yaml_dict, model_list) + r = callInstrument(model_list) + return r + if __name__ == "__main__": - if len(sys.argv[1:]) < 1 or sys.argv[1] == '--help' or sys.argv[1] == '-h': - usage() - sys.exit(0) - r = main() - sys.exit(r) \ No newline at end of file + if len(sys.argv[1:]) < 1 or sys.argv[1] == "--help" or sys.argv[1] == "-h": + usage() + sys.exit(0) + r = main() + sys.exit(r) diff --git a/bin/batchProfile.py b/bin/batchProfile.py index 48fea10a..ad22cddd 100755 --- a/bin/batchProfile.py +++ b/bin/batchProfile.py @@ -10,102 +10,111 @@ You need to run \'batchInstrument\' first, then run %(prog)s under the same directory, the directory that contains multiple sub directories for different software faults. Same as \'batchInstrument\', %(prog)s is only applicable when multiple software failure modes are defined in input.yaml. """ -import sys, os, shutil +import sys, os import yaml import subprocess prog = os.path.basename(sys.argv[0]) script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) -import llvm_paths -profile_script = os.path.join(script_path, 'profile') +profile_script = os.path.join(script_path, "profile") # basedir and options are assigned in parseArgs(args) basedir = "" options = [] + def parseArgs(args): - global basedir - global options - cwd = os.getcwd() - for arg in args: - option = arg - if os.path.isfile(arg): - basedir = os.path.realpath(os.path.dirname(arg)) - option = os.path.basename(arg) - options.append(option) - os.chdir(basedir) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = -1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + global basedir + for arg in args: + option = arg + if os.path.isfile(arg): + basedir = os.path.realpath(os.path.dirname(arg)) + option = os.path.basename(arg) + options.append(option) + os.chdir(basedir) + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = -1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def phraseMasterYaml(): - master_yaml_dict = {} - model_list = [] - try: - with open('input.yaml', 'r') as master_yaml_file: - master_yaml_dict = yaml.safe_load(master_yaml_file) - except: - print ("ERROR: Unable to find input.yaml or load the input.yaml under current directory") - print (basedir) - sys.exit(-1) - try: - model_list = list(master_yaml_dict['compileOption']['instSelMethod'][0]['customInstselector']['include']) - except: - print ("ERROR: this wrapper script is not applicable on the input.yaml under current directory. Please note this script is only applicable on input.yaml files with multiple software failure models defined.") - print (basedir) - sys.exit(-1) - return master_yaml_dict, model_list + master_yaml_dict = {} + model_list = [] + try: + with open("input.yaml", "r") as master_yaml_file: + master_yaml_dict = yaml.safe_load(master_yaml_file) + except Exception: + print( + "ERROR: Unable to find input.yaml or load the input.yaml under current directory" + ) + print(basedir) + sys.exit(-1) + try: + model_list = list( + master_yaml_dict["compileOption"]["instSelMethod"][0]["customInstselector"][ + "include" + ] + ) + except Exception: + print( + "ERROR: this wrapper script is not applicable on the input.yaml under current directory. Please note this script is only applicable on input.yaml files with multiple software failure models defined." + ) + print(basedir) + sys.exit(-1) + return master_yaml_dict, model_list + def callProfile(model_list, *argv): - num_failed = 0 - for model in model_list: - workdir = os.path.join(basedir, "llfi-"+model) - try: - os.chdir(workdir) - except: - print ("ERROR: Unable to change to directory:", workdir) - sys.exit(-1) - profile_exe_name = argv[0] - print (profile_exe_name) - if profile_exe_name.endswith('.ll'): - profile_exe_name = profile_exe_name.split('.ll')[0] - elif profile_exe_name.endswith('.bc'): - profile_exe_name = profile_exe_name.split('.bc')[0] - profile_exe_name = profile_exe_name + '-profiling.exe' - command = [profile_script] - command.extend(['./llfi/'+profile_exe_name]) - command.extend(argv[1:]) - print ("\nRun profiling command:", ' '.join(command)) - try: - o = subprocess.check_output(command, stderr=sys.stderr) - except subprocess.CalledProcessError: - print ("profiling:", model, " failed!") - num_failed += 1 - else: - print (o.decode()) - print ("profiling:", model, " succeed!") - os.chdir(basedir) - return num_failed + num_failed = 0 + for model in model_list: + workdir = os.path.join(basedir, "llfi-" + model) + try: + os.chdir(workdir) + except Exception: + print("ERROR: Unable to change to directory:", workdir) + sys.exit(-1) + profile_exe_name = argv[0] + print(profile_exe_name) + if profile_exe_name.endswith(".ll"): + profile_exe_name = profile_exe_name.split(".ll")[0] + elif profile_exe_name.endswith(".bc"): + profile_exe_name = profile_exe_name.split(".bc")[0] + profile_exe_name = profile_exe_name + "-profiling.exe" + command = [profile_script] + command.extend(["./llfi/" + profile_exe_name]) + command.extend(argv[1:]) + print("\nRun profiling command:", " ".join(command)) + try: + o = subprocess.check_output(command, stderr=sys.stderr) + except subprocess.CalledProcessError: + print("profiling:", model, " failed!") + num_failed += 1 + else: + print(o.decode()) + print("profiling:", model, " succeed!") + os.chdir(basedir) + return num_failed + def main(*argv): - global options - parseArgs(argv) - master_yaml_dict, model_list = phraseMasterYaml() - r = callProfile(model_list, *options) - return r + parseArgs(argv) + master_yaml_dict, model_list = phraseMasterYaml() + r = callProfile(model_list, *options) + return r + if __name__ == "__main__": - if len(sys.argv[1:]) < 1 or sys.argv[1] == '--help' or sys.argv[1] == '-h': - usage() - sys.exit(0) - else: - argv = sys.argv[1:] - r = main(*argv) - sys.exit(r) \ No newline at end of file + if len(sys.argv[1:]) < 1 or sys.argv[1] == "--help" or sys.argv[1] == "-h": + usage() + sys.exit(0) + else: + argv = sys.argv[1:] + r = main(*argv) + sys.exit(r) diff --git a/bin/injectfault.py b/bin/injectfault.py index 33566944..911e9010 100755 --- a/bin/injectfault.py +++ b/bin/injectfault.py @@ -18,13 +18,16 @@ # This script injects faults the program and produces output # This script should be run after the profiling step -import sys, os, subprocess -import yaml -import time +import os import random import shutil +import subprocess +import sys +import time from subprocess import TimeoutExpired +import yaml + runOverride = False optionlist = [] defaultTimeout = 500 @@ -37,637 +40,722 @@ fi_exe = "" options = { - "verbose": False, + "verbose": False, } -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) def parseArgs(args): - global optionlist, fi_exe - if args[0] == "--help" or args[0] == "-h": - usage() - fi_exe = os.path.realpath(args[0]) - basedir = basedir = os.path.abspath(os.path.dirname(os.path.dirname(fi_exe))) - optionlist = args[1:] - - # remove the directory prefix for input files, this is to make it easier for the program - # to take a snapshot - for index, opt in enumerate(optionlist): - if os.path.isfile(opt): - if os.path.realpath(os.path.dirname(opt)) != basedir: - usage("File %s passed through option is not under current directory" % opt) - else: - optionlist[index] = os.path.basename(opt) - - if basedir != os.getcwd(): - print("Change directory to:", basedir) - os.chdir(basedir) + global optionlist, fi_exe + if args[0] == "--help" or args[0] == "-h": + usage() + fi_exe = os.path.realpath(args[0]) + basedir = basedir = os.path.abspath(os.path.dirname(os.path.dirname(fi_exe))) + optionlist = args[1:] + + # remove the directory prefix for input files, this is to make it easier for the program + # to take a snapshot + for index, opt in enumerate(optionlist): + if os.path.isfile(opt): + if os.path.realpath(os.path.dirname(opt)) != basedir: + usage( + "File %s passed through option is not under current directory" % opt + ) + else: + optionlist[index] = os.path.basename(opt) + + if basedir != os.getcwd(): + print("Change directory to:", basedir) + os.chdir(basedir) def checkInputYaml(): - global doc - global defaultTimeout - #Check for input.yaml's presence - yamldir = os.path.dirname(os.path.dirname(fi_exe)) - try: - f = open(os.path.join(basedir, 'input.yaml'),'r') - except: - usage("No input.yaml file in the parent directory of fault injection executable") - exit(1) - - #Check for input.yaml's correct formmating - try: - doc = yaml.safe_load(f) - except: - f.close() - usage("input.yaml is not formatted in proper YAML (reminder: use spaces, not tabs)") - exit(1) - finally: - f.close() - - if "kernelOption" in doc: - for opt in doc["kernelOption"]: - if opt=="forceRun": - runOverride = True - print("Kernel: Forcing run") - if "defaultTimeout" in doc: - defaultTimeout = int(doc["defaultTimeout"]) - assert defaultTimeout > 0, "The timeOut option must be greater than 0" - else: - print("Default timeout is set to " + str(defaultTimeout) + " by default.") + global doc + global defaultTimeout + global runOverride + # Check for input.yaml's presence + try: + with open(os.path.join(basedir, "input.yaml"), "r") as f: + doc = yaml.safe_load(f) + except OSError: + usage( + "No input.yaml file in the parent directory of fault injection executable" + ) + sys.exit(1) + except Exception: + usage( + "input.yaml is not formatted in proper YAML (reminder: use spaces, not tabs)" + ) + sys.exit(1) + + if "kernelOption" in doc: + for opt in doc["kernelOption"]: + if opt == "forceRun": + runOverride = True + print("Kernel: Forcing run") + if "defaultTimeout" in doc: + defaultTimeout = int(doc["defaultTimeout"]) + assert defaultTimeout > 0, "The timeOut option must be greater than 0" + else: + print("Default timeout is set to " + str(defaultTimeout) + " by default.") def print_progressbar(idx, nruns): - pct = (float(idx) / float(nruns)) - WIDTH = 50 - bar = "=" * int(pct * WIDTH) - bar += ">" - bar += "-" * (WIDTH - int(pct * WIDTH)) - print(("\r[%s] %.1f%% (%d / %d)" % (bar, pct * 100, idx, nruns)), end='\n') - sys.stdout.flush() + pct = float(idx) / float(nruns) + WIDTH = 50 + bar = "=" * int(pct * WIDTH) + bar += ">" + bar += "-" * (WIDTH - int(pct * WIDTH)) + print(("\r[%s] %.1f%% (%d / %d)" % (bar, pct * 100, idx, nruns)), end="\n") + sys.stdout.flush() ################################################################################ def config(): - global inputdir, outputdir, errordir, stddir, llfi_stat_dir - # config - llfi_dir = os.path.dirname(fi_exe) - inputdir = os.path.join(llfi_dir, "prog_input") - outputdir = os.path.join(llfi_dir, "prog_output") - errordir = os.path.join(llfi_dir, "error_output") - stddir = os.path.join(llfi_dir, "std_output") - llfi_stat_dir = os.path.join(llfi_dir, "llfi_stat_output") - - if not os.path.isdir(outputdir): - os.mkdir(outputdir) - if not os.path.isdir(errordir): - os.mkdir(errordir) - if not os.path.isdir(inputdir): - os.mkdir(inputdir) - if not os.path.isdir(stddir): - os.mkdir(stddir) - if not os.path.isdir(llfi_stat_dir): - os.mkdir(llfi_stat_dir) + global inputdir, outputdir, errordir, stddir, llfi_stat_dir + # config + llfi_dir = os.path.dirname(fi_exe) + inputdir = os.path.join(llfi_dir, "prog_input") + outputdir = os.path.join(llfi_dir, "prog_output") + errordir = os.path.join(llfi_dir, "error_output") + stddir = os.path.join(llfi_dir, "std_output") + llfi_stat_dir = os.path.join(llfi_dir, "llfi_stat_output") + + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + if not os.path.isdir(errordir): + os.mkdir(errordir) + if not os.path.isdir(inputdir): + os.mkdir(inputdir) + if not os.path.isdir(stddir): + os.mkdir(stddir) + if not os.path.isdir(llfi_stat_dir): + os.mkdir(llfi_stat_dir) ################################################################################ -def execute( execlist, timeout): - global outputfile - global return_codes - print(' '.join(execlist)) - #get state of directory - dirSnapshot() - p = subprocess.Popen(execlist, stdout = subprocess.PIPE) - outputFile = open(outputfile, "wb") - program_timed_out = False - start_time = 0 - elapsetime = 0 - - #communicate() will block until program exits or until timeout is reached - try: - start_time = time.time() - (p_stdout,p_stderr) = p.communicate(timeout=timeout) - elapsetime = int(time.time() - start_time + 1) - except TimeoutExpired: #Child process timed out - p.kill() #Need to kill the process and then clean up commmunication - (p_stdout,p_stderr) = p.communicate(timeout=timeout) - program_timed_out = True - - moveOutput() - if program_timed_out: - print("\tParent : Child timed out. Cleaning up ... ") - else: - print("\t program finish", p.returncode) - print("\t time taken", elapsetime,"\n") - outputFile = open(outputfile, "wb") - - if program_timed_out: - outputFile.write( - bytes("\n\n ### Process killed by LLFI for timing out ###\n","UTF-8")) - - outputFile.write(p_stdout) - - if program_timed_out: - outputFile.write( - bytes("\n\n ### Process killed by LLFI for timing out ###\n","UTF-8")) - - outputFile.close() - replenishInput() #for cases where program deletes input or alters them each run - - # Keep a dict of all return codes received. - if program_timed_out: - if "TO" in return_codes: - return_codes["TO"] += 1 +def execute(execlist, timeout): + print(" ".join(execlist)) + # get state of directory + dirSnapshot() + p = subprocess.Popen(execlist, stdout=subprocess.PIPE) + program_timed_out = False + start_time = 0 + elapsetime = 0 + + # communicate() will block until program exits or until timeout is reached + try: + start_time = time.time() + p_stdout, p_stderr = p.communicate(timeout=timeout) + elapsetime = int(time.time() - start_time + 1) + except TimeoutExpired: # Child process timed out + p.kill() # Need to kill the process and then clean up commmunication + p_stdout, p_stderr = p.communicate(timeout=timeout) + program_timed_out = True + + moveOutput() + if program_timed_out: + print("\tParent : Child timed out. Cleaning up ... ") else: - return_codes["TO"] = 1 - else: - if p.returncode in return_codes: - return_codes[p.returncode] += 1 + print("\t program finish", p.returncode) + print("\t time taken", elapsetime, "\n") + with open(outputfile, "wb") as outputFile: + if program_timed_out: + outputFile.write( + bytes("\n\n ### Process killed by LLFI for timing out ###\n", "UTF-8") + ) + + outputFile.write(p_stdout) + + if program_timed_out: + outputFile.write( + bytes("\n\n ### Process killed by LLFI for timing out ###\n", "UTF-8") + ) + replenishInput() # for cases where program deletes input or alters them each run + + # Keep a dict of all return codes received. + if program_timed_out: + if "TO" in return_codes: + return_codes["TO"] += 1 + else: + return_codes["TO"] = 1 else: - return_codes[p.returncode] = 1 - - if program_timed_out: - return "timed-out" - else: - return str(p.returncode) + if p.returncode in return_codes: + return_codes[p.returncode] += 1 + else: + return_codes[p.returncode] = 1 + if program_timed_out: + return "timed-out" + else: + return str(p.returncode) ################################################################################ def storeInputFiles(): - global inputList - inputList=[] - ##========Consider comma as separator of arguments ================================== - temp_optionlist = [] - for item in optionlist: - if item.count(',') == 0: - temp_optionlist.append(item) - else: - temp_optionlist.extend(item.split(',')) - ##=================================================================================== - for opt in temp_optionlist: - if os.path.isfile(opt):#stores all files in inputList and copy over to inputdir - shutil.copy2(opt, os.path.join(inputdir, opt)) - inputList.append(opt) + global inputList + inputList = [] + ##========Consider comma as separator of arguments ================================== + temp_optionlist = [] + for item in optionlist: + if item.count(",") == 0: + temp_optionlist.append(item) + else: + temp_optionlist.extend(item.split(",")) + ##=================================================================================== + for opt in temp_optionlist: + if os.path.isfile( + opt + ): # stores all files in inputList and copy over to inputdir + shutil.copy2(opt, os.path.join(inputdir, opt)) + inputList.append(opt) + ################################################################################ -def replenishInput():#TODO make condition to skip this if input is present - for each in inputList: - if not os.path.isfile(each):#copy deleted inputfiles back to basedir - shutil.copy2(os.path.join(inputdir, each), each) +def replenishInput(): # TODO make condition to skip this if input is present + for each in inputList: + if not os.path.isfile(each): # copy deleted inputfiles back to basedir + shutil.copy2(os.path.join(inputdir, each), each) + ################################################################################ def moveOutput(): - #move all newly created files - newfiles = [_file for _file in os.listdir(".")] - for each in newfiles: - if each not in dirBefore: - fileSize = os.stat(each).st_size - if fileSize == 0 and each.startswith("llfi"): - #empty library output, can delete - print(each+ " is going to be deleted for having size of " + str(fileSize)) - os.remove(each) - else: - flds = each.split(".") - newName = '.'.join(flds[0:-1]) - newName+='.'+run_id+'.'+flds[-1] - if newName.startswith("llfi"): - os.rename(each, os.path.join(llfi_stat_dir, newName)) - else: - os.rename(each, os.path.join(outputdir, newName)) + # move all newly created files + newfiles = [_file for _file in os.listdir(".")] + for each in newfiles: + if each not in dirBefore: + fileSize = os.stat(each).st_size + if fileSize == 0 and each.startswith("llfi"): + # empty library output, can delete + print( + each + " is going to be deleted for having size of " + str(fileSize) + ) + os.remove(each) + else: + flds = each.split(".") + newName = ".".join(flds[0:-1]) + newName += "." + run_id + "." + flds[-1] + if newName.startswith("llfi"): + os.rename(each, os.path.join(llfi_stat_dir, newName)) + else: + os.rename(each, os.path.join(outputdir, newName)) + ################################################################################ def dirSnapshot(): - #snapshot of directory before each execute() is performed - global dirBefore - dirBefore = [_file for _file in os.listdir(".")] + # snapshot of directory before each execute() is performed + global dirBefore + dirBefore = [_file for _file in os.listdir(".")] + ################################################################################ def readCycles(): - global totalcycles, fi_ml_stats - profinput= open("llfi.stat.prof.txt","r") - - while 1: - line = profinput.readline() - if not line: - break - if line.strip(): - # skip comments - if line.startswith("#"): - continue - label, value = line.split("=") - - # parse the line - if label == 'total_cycle': - totalcycles = int(value) - elif label == 'ml_layer': - layerNum, layerName, cycleStart, cycleEnd = value.split(",") - fi_ml_stats.append([int(layerNum), layerName, int(cycleStart), int(cycleEnd)]) - - profinput.close() + global totalcycles + with open("llfi.stat.prof.txt", "r") as profinput: + while 1: + line = profinput.readline() + if not line: + break + if line.strip(): + # skip comments + if line.startswith("#"): + continue + label, value = line.split("=") + + # parse the line + if label == "total_cycle": + totalcycles = int(value) + elif label == "ml_layer": + layerNum, layerName, cycleStart, cycleEnd = value.split(",") + fi_ml_stats.append( + [int(layerNum), layerName, int(cycleStart), int(cycleEnd)] + ) + ################################################################################ -def checkValues(key, val, var1 = None,var2 = None,var3 = None,var4 = None): - #preliminary input checking for fi options - #also checks for fi_bit usage by non-kernel users - #optional var# are used for fi_bit's case only - if key =='run_number': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val)>0, key+" must be greater than 0 in input.yaml" - - elif key == 'fi_type': - pass - - ##======== Add number of corrupted bits QINING @MAR 13th======== - elif key == 'fi_num_bits': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >=1, key+" must be greater than or equal to 1 in input.yaml" - ##============================================================== - - ##======== Add second corrupted regs QINING @MAR 27th=========== - elif key == "window_len": - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >=0, key+" must be greater than or equal to zero in input.yaml" - ##================================================================== - - ##BEHROOZ: Add max number of target locations - elif key == "fi_max_multiple": - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >0, key+" must be greater than zero in input.yaml" - assert int(val) <=int(fi_max_multiple_default), key+" must be smaller than or equal to "+str(fi_max_multiple_default)+ " in input.yaml" - ##============================================================== - - ##BEHROOZ: Add multiple corrupted regs - elif key == "window_len_multiple": - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >0, key+" must be greater than zero in input.yaml" - elif key == "window_len_multiple_startindex": - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >0, key+" must be greater than zero in input.yaml" - elif key == "window_len_multiple_endindex": - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >0, key+" must be greater than zero in input.yaml" - - ##============================================================== - - elif key == 'fi_cycle': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - ##BEHROOZ: I changed the below line to the current one to fix the fi_cycle - assert int(val) > 0, key+" must be greater than 0 in input.yaml" - #assert int(val) >= 0, key+" must be greater than or equal to 0 in input.yaml" - assert int(val) <= int(totalcycles), key +" must be less than or equal to "+totalcycles.strip()+" in input.yaml" - - elif key == 'fi_index': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >= 0, key+" must be greater than or equal to 0 in input.yaml" - - elif key == 'fi_reg_index': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >= 0, key+" must be greater than or equal to 0 in input.yaml" - - elif key == 'fi_bit': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >= 0, key+" must be greater than or equal to 0 in input.yaml" - if runOverride: - pass - elif var1 != None and var1 > 1 and (var2 or var3) and var4: - user_input = input("\nWARNING: Injecting into the same cycle(index), bit multiple times "+ - "is redundant as it would yield the same result."+ - "\nTo turn off this warning, please see Readme "+ - "for kernel mode.\nDo you wish to continue anyway? (Y/N)\n ") - if user_input.upper() =="Y": +def checkValues(key, val, var1=None, var2=None, var3=None, var4=None): + # preliminary input checking for fi options + # also checks for fi_bit usage by non-kernel users + # optional var# are used for fi_bit's case only + if key == "run_number": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) > 0, key + " must be greater than 0 in input.yaml" + + elif key == "fi_type": pass - else: - exit(1) - elif key == 'fi_random_seed': - assert isinstance(val, int)==True, key+" must be an integer in input.yaml" - assert int(val) >= 0, key+" must be greater than or equal to 0 in input.yaml" + elif key == "fi_num_bits": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) >= 1, key + " must be greater than or equal to 1 in input.yaml" + + elif key == "window_len": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) >= 0, ( + key + " must be greater than or equal to zero in input.yaml" + ) + + elif key == "fi_max_multiple": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) > 0, key + " must be greater than zero in input.yaml" + assert int(val) <= int(fi_max_multiple_default), ( + key + + " must be smaller than or equal to " + + str(fi_max_multiple_default) + + " in input.yaml" + ) + + elif key == "window_len_multiple": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) > 0, key + " must be greater than zero in input.yaml" + elif key == "window_len_multiple_startindex": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) > 0, key + " must be greater than zero in input.yaml" + elif key == "window_len_multiple_endindex": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) > 0, key + " must be greater than zero in input.yaml" + + elif key == "fi_cycle": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) > 0, key + " must be greater than 0 in input.yaml" + assert int(val) <= int(totalcycles), ( + key + + " must be less than or equal to " + + totalcycles.strip() + + " in input.yaml" + ) + + elif key == "fi_index": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) >= 0, key + " must be greater than or equal to 0 in input.yaml" + + elif key == "fi_reg_index": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) >= 0, key + " must be greater than or equal to 0 in input.yaml" + + elif key == "fi_bit": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) >= 0, key + " must be greater than or equal to 0 in input.yaml" + if runOverride: + pass + elif var1 is not None and var1 > 1 and (var2 or var3) and var4: + user_input = input( + "\nWARNING: Injecting into the same cycle(index), bit multiple times " + + "is redundant as it would yield the same result." + + "\nTo turn off this warning, please see Readme " + + "for kernel mode.\nDo you wish to continue anyway? (Y/N)\n " + ) + if user_input.upper() == "Y": + pass + else: + sys.exit(1) + + elif key == "fi_random_seed": + assert isinstance(val, int), key + " must be an integer in input.yaml" + assert int(val) >= 0, key + " must be greater than or equal to 0 in input.yaml" + ################################################################################ def main(args): - global optionlist, outputfile, totalcycles,run_id, return_codes - global defaultTimeout - - parseArgs(args) - checkInputYaml() - config() - - # get total num of cycles - readCycles() - storeInputFiles() - - #Set up each config file and its corresponding run_number - try: - rOpt = doc["runOption"] - except: - print("ERROR: Please include runOption in input.yaml.") - exit(1) - - if not os.path.isfile(fi_exe): - print("ERROR: The executable "+ fi_exe+" does not exist.") - print("Please build the executables with create-executables.\n") - exit(1) - else: - print("======Fault Injection======") - for ii, run in enumerate(rOpt): - # Maintain a dict of all return codes received and print summary at end - return_codes = {} - - # Put an empty line between configs - if ii > 0: - print("") - print("---FI Config #"+str(ii)+"---") - - if "numOfRuns" not in run["run"]: - print("ERROR: Must include a run number per fi config in input.yaml.") - exit(1) - - if "timeOut" in run["run"]: - timeout = int(run["run"]["timeOut"]) - assert timeout > 0, "The timeOut option must be greater than 0" - else: - timeout = defaultTimeout - print("Run with default timeout " + str(timeout)) - - run_number=run["run"]["numOfRuns"] - checkValues("run_number", run_number) - - # check for verbosity option, set at the FI run level - if "verbose" in run["run"]: - options["verbose"] = run["run"]["verbose"] - - # reset all configurations - if 'fi_type' in locals(): - del fi_type - if 'fi_cycle' in locals(): - del fi_cycle - if 'fi_index' in locals(): - del fi_index - if 'fi_reg_index' in locals(): - del fi_reg_index - if 'fi_bit' in locals(): - del fi_bit - ##======== Add number of corrupted bits QINING @MAR 13th======== - if 'fi_num_bits' in locals(): - del fi_num_bits - ##============================================================== - ##======== Add second corrupted regs QINING @MAR 27th=========== - if 'window_len' in locals(): - del window_len - ##============================================================== - if 'fi_random_seed' in locals(): - del fi_random_seed - ##============================================================== - ##BEHROOZ: Add max number of target locations - if 'fi_max_multiple' in locals(): - del fi_max_multiple - ##============================================================== - ##BEHROOZ: Add multiple corrupted regs - if 'window_len_multiple' in locals(): - del window_len_multiple - if 'window_len_multiple_startindex' in locals(): - del window_len_multiple_startindex - if 'window_len_multiple_endindex' in locals(): - del window_len_multiple_endindex - ##============================================================== - #write new fi config file according to input.yaml - if "fi_type" in run["run"]: - fi_type=run["run"]["fi_type"] - if fi_type == "SoftwareFault" or fi_type == "AutoInjection" or fi_type == "Automated": - try: - cOpt = doc["compileOption"] - injectorname = cOpt["instSelMethod"][0]["customInstselector"]["include"][0] - except: - print("\n\nERROR: Cannot extract fi_type from instSelMethod. Please check the customInstselector field in input.yaml\n") - else: - fi_type = injectorname - checkValues("fi_type",fi_type) - ##======== Add number of corrupted bits QINING @MAR 13th======== - if "fi_num_bits" in run["run"]: - fi_num_bits=run["run"]["fi_num_bits"] - checkValues("fi_num_bits", fi_num_bits) - ##============================================================== - ##======== Add second corrupted regs QINING @MAR 27th=========== - if 'window_len' in run["run"]: - window_len=run["run"]["window_len"] - checkValues("window_len", window_len) - ##============================================================== - ##BEHROOZ: Add max number of target locations - if 'fi_max_multiple' in run["run"]: - fi_max_multiple=run["run"]["fi_max_multiple"] - checkValues("fi_max_multiple", fi_max_multiple) - if ('fi_max_multiple' in locals()) and 'window_len' in locals(): - print(("\nERROR: window_len and fi_max_multiple cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - ##============================================================== - ##BEHROOZ: Add multiple corrupted regs - if 'window_len_multiple' in run["run"]: - window_len_multiple=run["run"]["window_len_multiple"] - checkValues("window_len_multiple", window_len_multiple) - if ('window_len_multiple' in locals()): - if ('window_len' in run["run"]): - print(("\nERROR: window_len and window_len_multiple cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - elif ('window_len_multiple_startindex' in run["run"]): - print(("\nERROR: window_len_multiple_startindex and window_len_multiple cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - elif ('window_len_multiple_endindex' in run["run"]): - print(("\nERROR: window_len_multiple_endindex and window_len_multiple cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - if 'window_len_multiple_startindex' in run["run"]: - window_len_multiple_startindex=run["run"]["window_len_multiple_startindex"] - checkValues("window_len_multiple_startindex", window_len_multiple_startindex) - if ('window_len_multiple_startindex' in locals()): - if ('window_len' in run["run"]): - print(("\nERROR: window_len and window_len_multiple_startindex cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - elif ('window_len_multiple' in run["run"]): - print(("\nERROR: window_len_multiple_startindex and window_len_multiple cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - elif ('window_len_multiple_endindex' not in run["run"]): - print(("\nERROR: window_len_multiple_startindex should come with window_len_multiple_endindex." - " Please specify both.")) - exit(1) - if 'window_len_multiple_endindex' in run["run"]: - window_len_multiple_endindex=run["run"]["window_len_multiple_endindex"] - checkValues("window_len_multiple_endindex", window_len_multiple_endindex) - if ('window_len_multiple_endindex' in locals()): - if('window_len' in run["run"]): - print(("\nERROR: window_len and window_len_multiple_endindex cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - elif('window_len_multiple' in run["run"]): - print(("\nERROR: window_len_multiple_endindex and window_len_multiple cannot be specified" - " at the same time in the input.yaml file. Please choose one.")) - exit(1) - elif('window_len_multiple_startindex' not in run["run"]): - print(("\nERROR: window_len_multiple_endindex should come with window_len_multiple_startindex." - " Please specify both.")) - exit(1) - ##============================================================== - if "fi_cycle" in run["run"]: - fi_cycle=run["run"]["fi_cycle"] - checkValues("fi_cycle",fi_cycle) - if "fi_index" in run["run"]: - fi_index=run["run"]["fi_index"] - checkValues("fi_index",fi_index) - if "fi_reg_index" in run["run"]: - fi_reg_index=run["run"]["fi_reg_index"] - checkValues("fi_reg_index",fi_reg_index) - if "fi_bit" in run["run"]: - fi_bit=run["run"]["fi_bit"] - checkValues("fi_bit",fi_bit,run_number,fi_cycle,fi_index,fi_reg_index) - if "fi_random_seed" in run["run"]: - fi_random_seed=run["run"]["fi_random_seed"] - checkValues("fi_random_seed",fi_random_seed) - - if ('fi_cycle' not in locals()) and 'fi_index' in locals(): - print(("\nINFO: You choose to inject faults based on LLFI index, " - "this will inject into every runtime instruction whose LLFI " - "index is %d\n" % fi_index)) - ##BEHROOZ: - if ('window_len_multiple' in locals() or 'window_len_multiple_startindex' in locals() or 'window_len_multiple_endindex' in locals()): - if('fi_max_multiple' not in locals()): - print(("\nINFO: You choose a window length for multiple bit-flip injection, " - "however you have not specified the maximum number of locations." - " Thus, the maximum number of locations will be chosen as " +str(fi_max_multiple_default)+ ".\n")) - fi_max_multiple = int(fi_max_multiple_default) - - if ('window_len_multiple' not in locals() and 'window_len_multiple_startindex' not in locals()) and 'fi_max_multiple' in locals(): - print(("\nINFO: You choose the maximum number of multiple bit injection, " - "however you have not specified the window length for multiple bit-flip injection." - " Thus, the window size will be chosen equal to the total number of cycles-1= " - + str(int(totalcycles)-1)+ ".\n")) - window_len_multiple = int(totalcycles) - 1 - ##====================================================== - need_to_calc_fi_cycle = True - if ('fi_cycle' in locals()) or 'fi_index' in locals(): - need_to_calc_fi_cycle = False - - # fault injection - for index in range(0, run_number): - run_id = str(ii)+"-"+str(index) - outputfile = stddir + "/std_outputfile-" + "run-"+run_id - errorfile = errordir + "/errorfile-" + "run-"+run_id - execlist = [fi_exe] - - if('fi_cycle' not in locals() and 'fi_random_seed' in locals()): - random.seed(fi_random_seed) - - if need_to_calc_fi_cycle: - ##BEHROOZ: I changed the below line to the current one to fix the fi_cycle - fi_cycle = random.randint(1, int(totalcycles)) - ##fi_cycle = random.randint(0, int(totalcycles) - 1) - - ficonfig_File = open("llfi.config.runtime.txt", 'w') - - global fi_ml_stats - if 'fi_cycle' in locals() and len(fi_ml_stats) > 0: - - # Find to which Ml layer this fi_cycle belongs to. - for i in range(0, len(fi_ml_stats)): - if fi_cycle >= fi_ml_stats[i][2] and fi_cycle <= fi_ml_stats[i][3]: - ficonfig_File.write("ml_layer_name="+fi_ml_stats[i][1]+'\n') - ficonfig_File.write("ml_layer_number="+str(fi_ml_stats[i][0])+'\n') - - if 'fi_cycle' in locals(): - ficonfig_File.write("fi_cycle="+str(fi_cycle)+'\n') - elif 'fi_index' in locals(): - ficonfig_File.write("fi_index="+str(fi_index)+'\n') - - if 'fi_type' in locals(): - ficonfig_File.write("fi_type="+fi_type+'\n') - if 'fi_reg_index' in locals(): - ficonfig_File.write("fi_reg_index="+str(fi_reg_index)+'\n') - if 'fi_bit' in locals(): - ficonfig_File.write("fi_bit="+str(fi_bit)+'\n') - ##======== Add number of corrupted bits QINING @MAR 13th======== - if 'fi_num_bits' in locals(): - ficonfig_File.write("fi_num_bits="+str(fi_num_bits)+'\n') - ##============================================================== - ##======== Add second corrupted regs QINING @MAR 27th=========== - if 'window_len' in locals(): - ##BEHROOZ: I changed the below line to the current one to fix the fi_cycle - fi_second_cycle = min(fi_cycle + random.randint(1, int(window_len)), int(totalcycles)) - #fi_second_cycle = min(fi_cycle + random.randint(1, int(window_len)), int(totalcycles) - 1) - ficonfig_File.write("fi_second_cycle="+str(fi_second_cycle)+'\n') - ##================================================================== - ##BEHROOZ: Add max number of target locations - if ('fi_max_multiple' in locals()): - win_start_index = 1 - win_end_index = 1 - if('window_len_multiple' in locals()): - win_end_index = int(window_len_multiple) - elif('window_len_multiple_startindex' in locals() and 'window_len_multiple_endindex' in locals()): - win_start_index = window_len_multiple_startindex - win_end_index = window_len_multiple_endindex - if(win_start_index > win_end_index): - print(("\nERROR: In the yaml file, the window_len_multiple_startindex cannot be bigger than window_len_multiple_endindex!")) - exit(1) - #The line below has been substituted with the one below it. This way the maximum number injection is not selected randomly and is - #equal to the value specified by the user - ##selected_num_of_injection = random.randint(1, int(fi_max_multiple)) - ficonfig_File.write("fi_max_multiple="+str(fi_max_multiple)+'\n') - selected_num_of_injection = fi_max_multiple - ##======The -1 here is because we have already selected the first location by choosing the fi-cycle - ##===== and here we are looking for the remaining cycles.================= - fi_next_cycle = fi_cycle - for index_multiple in range(1, int(selected_num_of_injection)): - fi_next_cycle = min(fi_next_cycle + random.randint(win_start_index, win_end_index), int(totalcycles)) - ficonfig_File.write("fi_next_cycle="+str(fi_next_cycle)+'\n') - if fi_next_cycle == int(totalcycles): - break - ##================================================================== - ficonfig_File.close() - - # print run index before executing. Comma removes newline for prettier - # formatting - execlist.extend(optionlist) - ret = execute(execlist, timeout) - if ret == "timed-out": - error_File = open(errorfile, 'w') - error_File.write("Program hang\n") - error_File.close() - elif int(ret) < 0: - error_File = open(errorfile, 'w') - error_File.write("Program crashed, terminated by the system, return code " + ret + '\n') - error_File.close() - elif int(ret) > 0: - error_File = open(errorfile, 'w') - error_File.write("Program crashed, terminated by itself, return code " + ret + '\n') - error_File.close() - - # Print updates, print the number of injections finished - print_progressbar(index+1, run_number) - - #print_progressbar(run_number, run_number) - print("") # progress bar needs a newline after 100% reached - # Print summary - if options["verbose"]: - print("========== SUMMARY ==========") - print("Return codes: (code:\toccurance)") - for r in list(return_codes.keys()): - print((" %3s: %5d" % (str(r), return_codes[r]))) + global outputfile, run_id, return_codes + + parseArgs(args) + checkInputYaml() + config() + + # get total num of cycles + readCycles() + storeInputFiles() + + # Set up each config file and its corresponding run_number + try: + rOpt = doc["runOption"] + except Exception: + print("ERROR: Please include runOption in input.yaml.") + sys.exit(1) + + if not os.path.isfile(fi_exe): + print("ERROR: The executable " + fi_exe + " does not exist.") + print("Please build the executables with create-executables.\n") + sys.exit(1) + else: + print("======Fault Injection======") + for ii, run in enumerate(rOpt): + # Maintain a dict of all return codes received and print summary at end + return_codes = {} + + # Put an empty line between configs + if ii > 0: + print("") + print("---FI Config #" + str(ii) + "---") + + if "numOfRuns" not in run["run"]: + print("ERROR: Must include a run number per fi config in input.yaml.") + sys.exit(1) + + if "timeOut" in run["run"]: + timeout = int(run["run"]["timeOut"]) + assert timeout > 0, "The timeOut option must be greater than 0" + else: + timeout = defaultTimeout + print("Run with default timeout " + str(timeout)) + + run_number = run["run"]["numOfRuns"] + checkValues("run_number", run_number) + + # check for verbosity option, set at the FI run level + if "verbose" in run["run"]: + options["verbose"] = run["run"]["verbose"] + + # reset all configurations + if "fi_type" in locals(): + del fi_type + if "fi_cycle" in locals(): + del fi_cycle + if "fi_index" in locals(): + del fi_index + if "fi_reg_index" in locals(): + del fi_reg_index + if "fi_bit" in locals(): + del fi_bit + if "fi_num_bits" in locals(): + del fi_num_bits + if "window_len" in locals(): + del window_len + if "fi_random_seed" in locals(): + del fi_random_seed + if "fi_max_multiple" in locals(): + del fi_max_multiple + if "window_len_multiple" in locals(): + del window_len_multiple + if "window_len_multiple_startindex" in locals(): + del window_len_multiple_startindex + if "window_len_multiple_endindex" in locals(): + del window_len_multiple_endindex + ##============================================================== + # write new fi config file according to input.yaml + if "fi_type" in run["run"]: + fi_type = run["run"]["fi_type"] + if ( + fi_type == "AutoInjection" + or fi_type == "Automated" + ): + try: + cOpt = doc["compileOption"] + injectorname = cOpt["instSelMethod"][0]["customInstselector"][ + "include" + ][0] + except Exception: + print( + "\n\nERROR: Cannot extract fi_type from instSelMethod. Please check the customInstselector field in input.yaml\n" + ) + else: + fi_type = injectorname + checkValues("fi_type", fi_type) + if "fi_num_bits" in run["run"]: + fi_num_bits = run["run"]["fi_num_bits"] + checkValues("fi_num_bits", fi_num_bits) + if "window_len" in run["run"]: + window_len = run["run"]["window_len"] + checkValues("window_len", window_len) + if "fi_max_multiple" in run["run"]: + fi_max_multiple = run["run"]["fi_max_multiple"] + checkValues("fi_max_multiple", fi_max_multiple) + if ("fi_max_multiple" in locals()) and "window_len" in locals(): + print( + ( + "\nERROR: window_len and fi_max_multiple cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + if "window_len_multiple" in run["run"]: + window_len_multiple = run["run"]["window_len_multiple"] + checkValues("window_len_multiple", window_len_multiple) + if "window_len_multiple" in locals(): + if "window_len" in run["run"]: + print( + ( + "\nERROR: window_len and window_len_multiple cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + elif "window_len_multiple_startindex" in run["run"]: + print( + ( + "\nERROR: window_len_multiple_startindex and window_len_multiple cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + elif "window_len_multiple_endindex" in run["run"]: + print( + ( + "\nERROR: window_len_multiple_endindex and window_len_multiple cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + if "window_len_multiple_startindex" in run["run"]: + window_len_multiple_startindex = run["run"][ + "window_len_multiple_startindex" + ] + checkValues( + "window_len_multiple_startindex", window_len_multiple_startindex + ) + if "window_len_multiple_startindex" in locals(): + if "window_len" in run["run"]: + print( + ( + "\nERROR: window_len and window_len_multiple_startindex cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + elif "window_len_multiple" in run["run"]: + print( + ( + "\nERROR: window_len_multiple_startindex and window_len_multiple cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + elif "window_len_multiple_endindex" not in run["run"]: + print( + ( + "\nERROR: window_len_multiple_startindex should come with window_len_multiple_endindex." + " Please specify both." + ) + ) + sys.exit(1) + if "window_len_multiple_endindex" in run["run"]: + window_len_multiple_endindex = run["run"][ + "window_len_multiple_endindex" + ] + checkValues( + "window_len_multiple_endindex", window_len_multiple_endindex + ) + if "window_len_multiple_endindex" in locals(): + if "window_len" in run["run"]: + print( + ( + "\nERROR: window_len and window_len_multiple_endindex cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + elif "window_len_multiple" in run["run"]: + print( + ( + "\nERROR: window_len_multiple_endindex and window_len_multiple cannot be specified" + " at the same time in the input.yaml file. Please choose one." + ) + ) + sys.exit(1) + elif "window_len_multiple_startindex" not in run["run"]: + print( + ( + "\nERROR: window_len_multiple_endindex should come with window_len_multiple_startindex." + " Please specify both." + ) + ) + sys.exit(1) + ##============================================================== + if "fi_cycle" in run["run"]: + fi_cycle = run["run"]["fi_cycle"] + checkValues("fi_cycle", fi_cycle) + if "fi_index" in run["run"]: + fi_index = run["run"]["fi_index"] + checkValues("fi_index", fi_index) + if "fi_reg_index" in run["run"]: + fi_reg_index = run["run"]["fi_reg_index"] + checkValues("fi_reg_index", fi_reg_index) + if "fi_bit" in run["run"]: + fi_bit = run["run"]["fi_bit"] + checkValues( + "fi_bit", fi_bit, run_number, fi_cycle, fi_index, fi_reg_index + ) + if "fi_random_seed" in run["run"]: + fi_random_seed = run["run"]["fi_random_seed"] + checkValues("fi_random_seed", fi_random_seed) + + if ("fi_cycle" not in locals()) and "fi_index" in locals(): + print( + ( + "\nINFO: You choose to inject faults based on LLFI index, " + "this will inject into every runtime instruction whose LLFI " + "index is %d\n" % fi_index + ) + ) + if ( + "window_len_multiple" in locals() + or "window_len_multiple_startindex" in locals() + or "window_len_multiple_endindex" in locals() + ): + if "fi_max_multiple" not in locals(): + print( + ( + "\nINFO: You choose a window length for multiple bit-flip injection, " + "however you have not specified the maximum number of locations." + " Thus, the maximum number of locations will be chosen as " + + str(fi_max_multiple_default) + + ".\n" + ) + ) + fi_max_multiple = int(fi_max_multiple_default) + + if ( + "window_len_multiple" not in locals() + and "window_len_multiple_startindex" not in locals() + ) and "fi_max_multiple" in locals(): + print( + ( + "\nINFO: You choose the maximum number of multiple bit injection, " + "however you have not specified the window length for multiple bit-flip injection." + " Thus, the window size will be chosen equal to the total number of cycles-1= " + + str(int(totalcycles) - 1) + + ".\n" + ) + ) + window_len_multiple = int(totalcycles) - 1 + ##====================================================== + need_to_calc_fi_cycle = True + if ("fi_cycle" in locals()) or "fi_index" in locals(): + need_to_calc_fi_cycle = False + + # fault injection + for index in range(0, run_number): + run_id = str(ii) + "-" + str(index) + outputfile = stddir + "/std_outputfile-" + "run-" + run_id + errorfile = errordir + "/errorfile-" + "run-" + run_id + execlist = [fi_exe] + + if "fi_cycle" not in locals() and "fi_random_seed" in locals(): + random.seed(fi_random_seed) + + if need_to_calc_fi_cycle: + fi_cycle = random.randint(1, int(totalcycles)) + + with open("llfi.config.runtime.txt", "w") as ficonfig_File: + if "fi_cycle" in locals() and len(fi_ml_stats) > 0: + # Find to which ML layer this fi_cycle belongs to. + for i in range(0, len(fi_ml_stats)): + if ( + fi_cycle >= fi_ml_stats[i][2] + and fi_cycle <= fi_ml_stats[i][3] + ): + ficonfig_File.write( + "ml_layer_name=" + fi_ml_stats[i][1] + "\n" + ) + ficonfig_File.write( + "ml_layer_number=" + str(fi_ml_stats[i][0]) + "\n" + ) + + if "fi_cycle" in locals(): + ficonfig_File.write("fi_cycle=" + str(fi_cycle) + "\n") + elif "fi_index" in locals(): + ficonfig_File.write("fi_index=" + str(fi_index) + "\n") + + if "fi_type" in locals(): + ficonfig_File.write("fi_type=" + fi_type + "\n") + if "fi_reg_index" in locals(): + ficonfig_File.write("fi_reg_index=" + str(fi_reg_index) + "\n") + if "fi_bit" in locals(): + ficonfig_File.write("fi_bit=" + str(fi_bit) + "\n") + if "fi_num_bits" in locals(): + ficonfig_File.write("fi_num_bits=" + str(fi_num_bits) + "\n") + if "window_len" in locals(): + fi_second_cycle = min( + fi_cycle + random.randint(1, int(window_len)), + int(totalcycles), + ) + ficonfig_File.write( + "fi_second_cycle=" + str(fi_second_cycle) + "\n" + ) + if "fi_max_multiple" in locals(): + win_start_index = 1 + win_end_index = 1 + if "window_len_multiple" in locals(): + win_end_index = int(window_len_multiple) + elif ( + "window_len_multiple_startindex" in locals() + and "window_len_multiple_endindex" in locals() + ): + win_start_index = window_len_multiple_startindex + win_end_index = window_len_multiple_endindex + if win_start_index > win_end_index: + print( + ( + "\nERROR: In the yaml file, the window_len_multiple_startindex cannot be bigger than window_len_multiple_endindex!" + ) + ) + sys.exit(1) + ficonfig_File.write( + "fi_max_multiple=" + str(fi_max_multiple) + "\n" + ) + selected_num_of_injection = fi_max_multiple + # The first fi_cycle location is already selected; find remaining cycles. + fi_next_cycle = fi_cycle + for _ in range(1, int(selected_num_of_injection)): + fi_next_cycle = min( + fi_next_cycle + + random.randint(win_start_index, win_end_index), + int(totalcycles), + ) + ficonfig_File.write( + "fi_next_cycle=" + str(fi_next_cycle) + "\n" + ) + if fi_next_cycle == int(totalcycles): + break + + # print run index before executing. Comma removes newline for prettier + # formatting + execlist.extend(optionlist) + ret = execute(execlist, timeout) + if ret == "timed-out": + with open(errorfile, "w") as error_File: + error_File.write("Program hang\n") + elif int(ret) < 0: + with open(errorfile, "w") as error_File: + error_File.write( + "Program crashed, terminated by the system, return code " + + ret + + "\n" + ) + elif int(ret) > 0: + with open(errorfile, "w") as error_File: + error_File.write( + "Program crashed, terminated by itself, return code " + + ret + + "\n" + ) + + # Print updates, print the number of injections finished + print_progressbar(index + 1, run_number) + + # print_progressbar(run_number, run_number) + print("") # progress bar needs a newline after 100% reached + # Print summary + if options["verbose"]: + print("========== SUMMARY ==========") + print("Return codes: (code:\toccurance)") + for r in list(return_codes.keys()): + print((" %3s: %5d" % (str(r), return_codes[r]))) + ################################################################################ -if __name__=="__main__": - if len(sys.argv) == 1: - usage('Must provide the fault injection executable and its options') - exit(1) - main(sys.argv[1:]) +if __name__ == "__main__": + if len(sys.argv) == 1: + usage("Must provide the fault injection executable and its options") + sys.exit(1) + main(sys.argv[1:]) diff --git a/bin/instrument.py b/bin/instrument.py index bf0f9d5b..b722fe0e 100755 --- a/bin/instrument.py +++ b/bin/instrument.py @@ -30,16 +30,15 @@ import subprocess script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) +sys.path.append(os.path.join(script_path, "../config")) import llvm_paths - optbin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/opt") llcbin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/llc") llvmgcc = os.path.join(llvm_paths.LLVM_GXX_BIN_DIR, "clang") llvmgxx = os.path.join(llvm_paths.LLVM_GXX_BIN_DIR, "clang++") llfilinklib = os.path.join(script_path, "../runtime_lib") -defaultlinklibs = ['-lpthread'] +defaultlinklibs = ["-lpthread"] prog = os.path.basename(sys.argv[0]) # basedir is assigned in parseArgs(args) basedir = "" @@ -49,396 +48,521 @@ # shutil.rmtree(llfibd) if sys.platform == "linux" or sys.platform == "linux2": - llfilib = os.path.join(script_path, "../llvm_passes/llfi-passes.so") + llfilib = os.path.join(script_path, "../llvm_passes/llfi-passes.so") elif sys.platform == "darwin": - llfilib = os.path.join(script_path, "../llvm_passes/llfi-passes.dylib") + llfilib = os.path.join(script_path, "../llvm_passes/llfi-passes.dylib") else: - print("ERROR: LLFI does not support platform " + sys.platform + ".") - exit(1) + print("ERROR: LLFI does not support platform " + sys.platform + ".") + sys.exit(1) options = { - "dir": "llfi", - "source": "", - "L": [], - "l": [], - "readable": False, - "verbose": False, - "IRonly": False, - "genDotGraph": False, - "useMLSpecificRT": False, - "enableMLFIStats" : False, + "dir": "llfi", + "source": "", + "L": [], + "l": [], + "readable": False, + "verbose": False, + "IRonly": False, + "genDotGraph": False, + "useMLSpecificRT": False, + "enableMLFIStats": False, } -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) def verbosePrint(msg, verbose): - if verbose: - print(msg) + if verbose: + print(msg) def parseArgs(args): - global options - argid = 0 - while argid < len(args): - arg = args[argid] - if arg.startswith("-"): - if arg == "--dir": - if options["dir"] != "llfi": - usage("Duplicated argument: " + arg) - argid += 1 - options["dir"] = args[argid].rstrip('/') - elif arg == "-L": + argid = 0 + while argid < len(args): + arg = args[argid] + if arg.startswith("-"): + if arg == "--dir": + if options["dir"] != "llfi": + usage("Duplicated argument: " + arg) + argid += 1 + options["dir"] = args[argid].rstrip("/") + elif arg == "-L": + argid += 1 + options["L"].append( + os.path.abspath(os.path.join(os.getcwd(), args[argid])) + ) + elif arg.startswith("-l"): + options["l"].append(arg[2:]) + elif arg == "--readable": + options["readable"] = True + elif arg == "--verbose": + options["verbose"] = True + elif arg == "--IRonly": + options["IRonly"] = True + elif arg == "--use-ml-specific-rt": + options["useMLSpecificRT"] = True + elif arg == "--enable-ML-FI-stats": + options["enableMLFIStats"] = True + elif arg == "--help" or arg == "-h": + usage() + else: + usage("Invalid argument: " + arg) + else: + if options["source"] != "": + usage("More than one source files are specified") + options["source"] = os.path.abspath(os.path.join(os.getcwd(), arg)) + basedir = os.path.dirname(options["source"]) + if basedir != os.path.abspath(os.getcwd()): + print("Change directory to:", basedir) + os.chdir(basedir) argid += 1 - options["L"].append(os.path.abspath(os.path.join(os.getcwd(), args[argid]))) - elif arg.startswith("-l"): - options["l"].append(arg[2:]) - elif arg == "--readable": - options["readable"] = True - elif arg == "--verbose": - options["verbose"] = True - elif arg == "--IRonly": - options["IRonly"] = True - elif arg == "--use-ml-specific-rt": - options["useMLSpecificRT"] = True - elif arg == "--enable-ML-FI-stats": - options["enableMLFIStats"] = True - elif arg == "--help" or arg == "-h": - usage() - else: - usage("Invalid argument: " + arg) - else: - if options["source"] != "": - usage("More than one source files are specified") - options["source"] = os.path.abspath(os.path.join(os.getcwd(), arg)) - basedir = os.path.dirname(options["source"]) - if basedir != os.path.abspath(os.getcwd()): - print("Change directory to:", basedir) - os.chdir(basedir) - argid += 1 - - if options["source"] == "": - usage("No input IR file specified") - - if '/' in options["dir"]: - usage("Cannot specify embedded directories for --dir") - else: - srcpath = os.path.dirname(options["source"]) - fullpath = os.path.join(srcpath, options["dir"]) - if os.path.exists(fullpath): - usage(options["dir"] + " already exists under " + srcpath + \ - ", you can either specify a different directory for --dir or " +\ - "remove " + options["dir"] + " from " + srcpath) + + if options["source"] == "": + usage("No input IR file specified") + + if "/" in options["dir"]: + usage("Cannot specify embedded directories for --dir") else: - try: - os.mkdir(fullpath) - options["dir"] = fullpath - except: - usage("Unable to create a directory named " + options["dir"] +\ - " under " + srcpath) + srcpath = os.path.dirname(options["source"]) + fullpath = os.path.join(srcpath, options["dir"]) + if os.path.exists(fullpath): + usage( + options["dir"] + + " already exists under " + + srcpath + + ", you can either specify a different directory for --dir or " + + "remove " + + options["dir"] + + " from " + + srcpath + ) + else: + try: + os.mkdir(fullpath) + options["dir"] = fullpath + except Exception: + usage( + "Unable to create a directory named " + + options["dir"] + + " under " + + srcpath + ) def checkInputYaml(): - #Check for input.yaml's presence - global cOpt - srcpath = os.path.dirname(options["source"]) - try: - f = open(os.path.join(srcpath, 'input.yaml'), 'r') - except: - print("ERROR: No input.yaml file in the %s directory." % srcpath) - os.rmdir(options["dir"]) - exit(1) - - #Check for input.yaml's correct formmating - try: - doc = yaml.safe_load(f) - f.close() - verbosePrint(yaml.dump(doc), options["verbose"]) - except: - print("Error: input.yaml is not formatted in proper YAML (reminder: use spaces, not tabs)") - os.rmdir(options["dir"]) - exit(1) - - #Check for compileOption in input.yaml - try: - cOpt = doc["compileOption"] - except: - os.rmdir(options["dir"]) - print("ERROR: Please include compileOptions in input.yaml.") - exit(1) - + # Check for input.yaml's presence + global cOpt + srcpath = os.path.dirname(options["source"]) + try: + with open(os.path.join(srcpath, "input.yaml"), "r") as f: + doc = yaml.safe_load(f) + verbosePrint(yaml.dump(doc), options["verbose"]) + except OSError: + print("ERROR: No input.yaml file in the %s directory." % srcpath) + os.rmdir(options["dir"]) + sys.exit(1) + except Exception: + print( + "Error: input.yaml is not formatted in proper YAML (reminder: use spaces, not tabs)" + ) + os.rmdir(options["dir"]) + sys.exit(1) + + # Check for compileOption in input.yaml + try: + cOpt = doc["compileOption"] + except KeyError: + os.rmdir(options["dir"]) + print("ERROR: Please include compileOptions in input.yaml.") + sys.exit(1) ################################################################################ def execCompilation(execlist): - verbosePrint(' '.join(execlist), options["verbose"]) - p = subprocess.Popen(execlist) - p.wait() - return p.returncode + verbosePrint(" ".join(execlist), options["verbose"]) + p = subprocess.Popen(execlist) + p.wait() + return p.returncode + ################################################################################ def readCompileOption(): - global compileOptions - - ###Instruction selection method - if "instSelMethod" not in cOpt: - print(("\n\nERROR: Please include an 'instSelMethod' key value pair under compileOption in input.yaml.\n")) - exit(1) - else: - compileOptions = [] - validMethods = ["insttype", "funcname", "customInstselector"] - # Generate list of instruction selection methods - # TODO: Generalize and document - instSelMethod = cOpt["instSelMethod"] - for method in instSelMethod: - methodName = list(method.keys())[0] - if methodName not in validMethods: - print ("\n\nERROR: Unknown instruction selection method in input.yaml.\n") - exit(1) - - #Select by instruction type - if methodName == "insttype" or methodName == "funcname": - compileOptions.append("-%s" % (str(methodName))) - #Select by custom instruction - elif methodName == "customInstselector": - compileOptions = ['-custominstselector'] - - # Ensure that 'include' is specified at least - # TODO: This isn't a very extendible way of doing this. - if "include" not in method[methodName]: - print(("\n\nERROR: An 'include' list must be present for the %s method in input.yaml.\n" % (methodName))) - exit(1) - - # Parse all options for current method - custom_instselector_defined = False - for attr in list(method[methodName].keys()): - if(attr == "include" or attr == "exclude"): - prefix = "-%s" % (str(attr)) - if methodName == "insttype": - prefix += "inst=" - elif methodName == "funcname": - prefix += "func=" + global compileOptions + + ###Instruction selection method + if "instSelMethod" not in cOpt: + print( + ( + "\n\nERROR: Please include an 'instSelMethod' key value pair under compileOption in input.yaml.\n" + ) + ) + sys.exit(1) + else: + compileOptions = [] + validMethods = ["insttype", "funcname", "customInstselector"] + # Generate list of instruction selection methods + # TODO: Generalize and document + instSelMethod = cOpt["instSelMethod"] + for method in instSelMethod: + methodName = list(method.keys())[0] + if methodName not in validMethods: + print( + "\n\nERROR: Unknown instruction selection method in input.yaml.\n" + ) + sys.exit(1) + + # Select by instruction type + if methodName == "insttype" or methodName == "funcname": + compileOptions.append("-%s" % (str(methodName))) + # Select by custom instruction elif methodName == "customInstselector": - prefix = "-fiinstselectorname=" - # For customInstselector, only one instruction selector is allowed - if custom_instselector_defined == True: - print("\nERROR: '\'instrument\' only support one customInstselector included in input.yaml.") - print("To apply a list of fault models/failure modes, please use \'batchinstrument\'") - exit(1) - else: - custom_instselector_defined = True - else: # add the ability to give custom options here? - pass - # Generate list of options for attribute - opts = [prefix + opt for opt in method[methodName][attr]] - compileOptions.extend(opts) - elif(attr == "options"): - opts = [opt for opt in method[methodName]["options"]] - compileOptions.extend(opts) - - ###Register selection method - if "regSelMethod" not in cOpt: - print(("\n\nERROR: Please include an 'regSelMethod' key value pair under compileOption in input.yaml.\n")) - exit(1) - else: - #Select by register location - if cOpt["regSelMethod"] == 'regloc': - compileOptions.append('-regloc') - if "regloc" not in cOpt: - print(("\n\nERROR: An 'regloc' key value pair must be present for the regloc method in input.yaml.\n")) - exit(1) - else: - compileOptions.append('-'+cOpt["regloc"]) - - #Select by custom register - elif cOpt["regSelMethod"] == 'customregselector': - compileOptions.append('-customregselector') - if "customRegSelector" not in cOpt: - print(("\n\nERROR: An 'customRegSelector' key value pair must be present for the customregselector method in input.yaml.\n")) - exit(1) - else: - if cOpt["customRegSelector"] == "SoftwareFault" or cOpt["customRegSelector"] == "Automatic": - ## replace the Automatic tag with the customInstSelector name - try: - regselectorname = cOpt["instSelMethod"][0]["customInstselector"]["include"][0] - except: - print("\n\nERROR: Cannot extract customRegSelector from instSelMethod. Please check the customInstselector field in input.yaml\n") + compileOptions = ["-custominstselector"] + + # Ensure that 'include' is specified at least + # TODO: This isn't a very extendible way of doing this. + if "include" not in method[methodName]: + print( + ( + "\n\nERROR: An 'include' list must be present for the %s method in input.yaml.\n" + % (methodName) + ) + ) + sys.exit(1) + + # Parse all options for current method + custom_instselector_defined = False + for attr in list(method[methodName].keys()): + if attr == "include" or attr == "exclude": + prefix = "-%s" % (str(attr)) + if methodName == "insttype": + prefix += "inst=" + elif methodName == "funcname": + prefix += "func=" + elif methodName == "customInstselector": + prefix = "-fiinstselectorname=" + # For customInstselector, only one instruction selector is allowed + if custom_instselector_defined: + print( + "\nERROR: ''instrument' only support one customInstselector included in input.yaml." + ) + print( + "To apply a list of fault models/failure modes, please use 'batchinstrument'" + ) + sys.exit(1) + else: + custom_instselector_defined = True + else: # add the ability to give custom options here? + pass + # Generate list of options for attribute + opts = [prefix + opt for opt in method[methodName][attr]] + compileOptions.extend(opts) + elif attr == "options": + opts = [opt for opt in method[methodName]["options"]] + compileOptions.extend(opts) + + ###Register selection method + if "regSelMethod" not in cOpt: + print( + ( + "\n\nERROR: Please include an 'regSelMethod' key value pair under compileOption in input.yaml.\n" + ) + ) + sys.exit(1) + else: + # Select by register location + if cOpt["regSelMethod"] == "regloc": + compileOptions.append("-regloc") + if "regloc" not in cOpt: + print( + ( + "\n\nERROR: An 'regloc' key value pair must be present for the regloc method in input.yaml.\n" + ) + ) + sys.exit(1) + else: + compileOptions.append("-" + cOpt["regloc"]) + + # Select by custom register + elif cOpt["regSelMethod"] == "customregselector": + compileOptions.append("-customregselector") + if "customRegSelector" not in cOpt: + print( + ( + "\n\nERROR: An 'customRegSelector' key value pair must be present for the customregselector method in input.yaml.\n" + ) + ) + sys.exit(1) else: - compileOptions.append('-firegselectorname='+regselectorname) - else: - compileOptions.append('-firegselectorname='+cOpt["customRegSelector"]) - if "customRegSelectorOption" in cOpt: - for opt in cOpt["customRegSelectorOption"]: - compileOptions.append(opt) + if ( + cOpt["customRegSelector"] == "Automatic" + ): + ## replace the Automatic tag with the customInstSelector name + try: + regselectorname = cOpt["instSelMethod"][0][ + "customInstselector" + ]["include"][0] + except Exception: + print( + "\n\nERROR: Cannot extract customRegSelector from instSelMethod. Please check the customInstselector field in input.yaml\n" + ) + else: + compileOptions.append("-firegselectorname=" + regselectorname) + else: + compileOptions.append( + "-firegselectorname=" + cOpt["customRegSelector"] + ) + if "customRegSelectorOption" in cOpt: + for opt in cOpt["customRegSelectorOption"]: + compileOptions.append(opt) + + else: + print(("\n\nERROR: Unknown Register selection method in input.yaml.\n")) + sys.exit(1) + + ###Injection Trace selection + if "includeInjectionTrace" in cOpt: + for trace in cOpt["includeInjectionTrace"]: + if trace == "forward": + compileOptions.append("-includeforwardtrace") + elif trace == "backward": + compileOptions.append("-includebackwardtrace") + else: + print( + ( + "\n\nERROR: Invalid value for trace (forward/backward allowed) in input.yaml.\n" + ) + ) + sys.exit(1) + + ###Tracing Proppass + if "tracingPropagation" in cOpt and cOpt["tracingPropagation"]: + print( + ( + "\nWARNING: You enabled 'tracingPropagation' option in input.yaml. " + "The generate executables will be able to output dynamic values for instructions. " + "However, the executables take longer time to execute. If you don't want the trace, " + "please disable the option and re-run %s." % prog + ) + ) + compileOptions.append("-insttracepass") + if "tracingPropagationOption" in cOpt: + if "debugTrace" in cOpt["tracingPropagationOption"]: + if ( + str(cOpt["tracingPropagationOption"]["debugTrace"]).lower() + == "true" + ): + compileOptions.append("-debugtrace") + if "maxTrace" in cOpt["tracingPropagationOption"]: + assert isinstance( + cOpt["tracingPropagationOption"]["maxTrace"], int + ), "maxTrace must be an integer in input.yaml" + assert ( + int(cOpt["tracingPropagationOption"]["maxTrace"]) > 0 + ), "maxTrace must be greater than 0 in input.yaml" + compileOptions.append("-maxtrace") + compileOptions.append(str(cOpt["tracingPropagationOption"]["maxTrace"])) + + ###Dot Graph Generation selection + if "generateCDFG" in cOpt["tracingPropagationOption"]: + if ( + str(cOpt["tracingPropagationOption"]["generateCDFG"]).lower() + == "true" + ): + options["genDotGraph"] = True - else: - print(("\n\nERROR: Unknown Register selection method in input.yaml.\n")) - exit(1) - - ###Injection Trace selection - if "includeInjectionTrace" in cOpt: - for trace in cOpt["includeInjectionTrace"]: - if trace == 'forward': - compileOptions.append('-includeforwardtrace') - elif trace == 'backward': - compileOptions.append('-includebackwardtrace') - else: - print(("\n\nERROR: Invalid value for trace (forward/backward allowed) in input.yaml.\n")) - exit(1) - - ###Tracing Proppass - if "tracingPropagation" in cOpt and cOpt["tracingPropagation"] == True: - print(("\nWARNING: You enabled 'tracingPropagation' option in input.yaml. " - "The generate executables will be able to output dynamic values for instructions. " - "However, the executables take longer time to execute. If you don't want the trace, " - "please disable the option and re-run %s." %prog)) - compileOptions.append('-insttracepass') - if 'tracingPropagationOption' in cOpt: - if "debugTrace" in cOpt["tracingPropagationOption"]: - if(str(cOpt["tracingPropagationOption"]["debugTrace"]).lower() == "true"): - compileOptions.append('-debugtrace') - if "maxTrace" in cOpt["tracingPropagationOption"]: - assert isinstance(cOpt["tracingPropagationOption"]["maxTrace"], int)==True, "maxTrace must be an integer in input.yaml" - assert int(cOpt["tracingPropagationOption"]["maxTrace"])>0, "maxTrace must be greater than 0 in input.yaml" - compileOptions.append('-maxtrace') - compileOptions.append(str(cOpt["tracingPropagationOption"]["maxTrace"])) - - ###Dot Graph Generation selection - if "generateCDFG" in cOpt["tracingPropagationOption"]: - if (str(cOpt["tracingPropagationOption"]["generateCDFG"]).lower() == "true"): - options["genDotGraph"] = True ################################################################################ def _suffixOfIR(): - if options["readable"]: - return ".ll" - else: - return ".bc" - -def compileProg(): - global proffile, fifile, compileOptions, defaultlinklibs - srcbase = os.path.basename(options["source"]) - progbin = os.path.join(options["dir"], srcbase[0 : srcbase.rfind(".")]) - - llfi_indexed_file = progbin + "-llfi_index" - proffile = progbin + "-profiling" - fifile = progbin + "-faultinjection" - tmpfiles = [] - - execlist = [optbin, '-load-pass-plugin', llfilib, '-genllfiindexpass', '-o', - llfi_indexed_file + _suffixOfIR(), options['source']] - if options["readable"]: - execlist.append('-S') - if options["genDotGraph"]: - execlist.append('-dotgraphpass') - retcode = execCompilation(execlist) - - if retcode == 0: - execlist = [optbin, '-load-pass-plugin', llfilib, '-profilingpass'] - execlist2 = ['-o', proffile + _suffixOfIR(), llfi_indexed_file + _suffixOfIR()] - execlist.extend(compileOptions) - execlist.extend(execlist2) if options["readable"]: - execlist.append("-S") - if options["enableMLFIStats"]: - execlist.append("-mlfistats") - retcode = execCompilation(execlist) + return ".ll" + else: + return ".bc" - if retcode == 0: - execlist = [optbin, '-load-pass-plugin', llfilib, '-faultinjectionpass'] - execlist2 = ['-o', fifile + _suffixOfIR(), llfi_indexed_file + _suffixOfIR()] - execlist.extend(compileOptions) - execlist.extend(execlist2) - #print(execlist) + +def compileProg(): + global proffile, fifile + srcbase = os.path.basename(options["source"]) + progbin = os.path.join(options["dir"], srcbase[0 : srcbase.rfind(".")]) + + llfi_indexed_file = progbin + "-llfi_index" + proffile = progbin + "-profiling" + fifile = progbin + "-faultinjection" + tmpfiles = [] + + # Separate pass names from cl::opt flags in compileOptions. + # In LLVM 17+, pass names must go in --passes=, not as -passname arguments. + _PASS_NAMES = {"-insttracepass"} + extra_passes = [opt.lstrip("-") for opt in compileOptions if opt in _PASS_NAMES] + compile_flags = [opt for opt in compileOptions if opt not in _PASS_NAMES] + + # Step 1: assign LLFI indices (+ optional dot graph) + index_passes = "genllfiindexpass" + if options["genDotGraph"]: + index_passes += ",dotgraphpass" + execlist = [ + optbin, + "-load-pass-plugin", + llfilib, + "--passes=" + index_passes, + "-o", + llfi_indexed_file + _suffixOfIR(), + options["source"], + ] if options["readable"]: - execlist.append("-S") + execlist.append("-S") retcode = execCompilation(execlist) - if retcode != 0: - print("\nERROR: there was an error during running the "\ - "instrumentation pass, please follow"\ - " the provided instructions for %s." % prog, file=sys.stderr) - shutil.rmtree(options['dir'], ignore_errors = True) - sys.exit(retcode) - - if not options["IRonly"]: - if retcode == 0: - execlist = [llcbin, '-filetype=obj', '-o', proffile + '.o', proffile + _suffixOfIR()] - tmpfiles.append(proffile + '.o') - retcode = execCompilation(execlist) - if retcode == 0: - execlist = [llcbin, '-filetype=obj', '-o', fifile + '.o', fifile + _suffixOfIR()] - tmpfiles.append(fifile + '.o') - retcode = execCompilation(execlist) - - liblist = list(defaultlinklibs) - for lib_dir in options["L"]: - liblist.extend(["-L", lib_dir]) - for lib in options["l"]: - liblist.append("-l" + lib) - liblist.append("-no-pie") - liblist.append("-Wl,-rpath") - liblist.append(llfilinklib) - if retcode == 0: - execlist = [llvmgcc, '-o', proffile + '.exe', proffile + '.o', '-L'+llfilinklib] - - # Check whether we should use static or dynamic FI RT - execlist.extend(["-lllfi-rt"]) - execlist.extend(liblist) - retcode = execCompilation(execlist) - if retcode != 0: - print("...Error compiling with " + os.path.basename(llvmgcc) + ", trying with " + os.path.basename(llvmgxx) + ".") - execlist[0] = llvmgxx + prof_passes = "profilingpass" + if extra_passes: + prof_passes += "," + ",".join(extra_passes) + execlist = [optbin, "-load-pass-plugin", llfilib, "--passes=" + prof_passes] + execlist2 = ["-o", proffile + _suffixOfIR(), llfi_indexed_file + _suffixOfIR()] + execlist.extend(compile_flags) + execlist.extend(execlist2) + if options["readable"]: + execlist.append("-S") + if options["enableMLFIStats"]: + execlist.append("-mlfistats") retcode = execCompilation(execlist) + if retcode == 0: - execlist = [llvmgcc, '-o', fifile + '.exe', fifile + '.o', '-L'+llfilinklib] - - # Check whether we should use static or dynamic FI RT - if options['useMLSpecificRT']: - execlist.extend(["-lml-lltfi-rt"]) - else: - execlist.extend(["-lllfi-rt"]) - - execlist.extend(liblist) - retcode = execCompilation(execlist) - if retcode != 0: - print("...Error compiling with " + os.path.basename(llvmgcc) + ", trying " + os.path.basename(llvmgxx) + ".") - execlist[0] = llvmgxx + execlist = [optbin, "-load-pass-plugin", llfilib, "--passes=faultinjectionpass"] + execlist2 = ["-o", fifile + _suffixOfIR(), llfi_indexed_file + _suffixOfIR()] + execlist.extend(compile_flags) + execlist.extend(execlist2) + if options["readable"]: + execlist.append("-S") retcode = execCompilation(execlist) - - for tmpfile in tmpfiles: - try: - os.remove(tmpfile) - except: - pass if retcode != 0: - print("\nERROR: there was an error during linking and generating executables,"\ - "Please take %s and %s and generate the executables manually (linking llfi-rt "\ - "in directory %s)." %(proffile + _suffixOfIR(), fifile + _suffixOfIR(), llfilinklib), file=sys.stderr) - sys.exit(retcode) - else: - print("\nSuccess", file=sys.stderr) + print( + "\nERROR: there was an error during running the " + "instrumentation pass, please follow" + " the provided instructions for %s." % prog, + file=sys.stderr, + ) + shutil.rmtree(options["dir"], ignore_errors=True) + sys.exit(retcode) + + if not options["IRonly"]: + if retcode == 0: + execlist = [ + llcbin, + "-filetype=obj", + "-o", + proffile + ".o", + proffile + _suffixOfIR(), + ] + tmpfiles.append(proffile + ".o") + retcode = execCompilation(execlist) + if retcode == 0: + execlist = [ + llcbin, + "-filetype=obj", + "-o", + fifile + ".o", + fifile + _suffixOfIR(), + ] + tmpfiles.append(fifile + ".o") + retcode = execCompilation(execlist) + + liblist = list(defaultlinklibs) + for lib_dir in options["L"]: + liblist.extend(["-L", lib_dir]) + for lib in options["l"]: + liblist.append("-l" + lib) + liblist.append("-no-pie") + liblist.append("-Wl,-rpath") + liblist.append(llfilinklib) + + if retcode == 0: + execlist = [ + llvmgcc, + "-o", + proffile + ".exe", + proffile + ".o", + "-L" + llfilinklib, + ] + + # Check whether we should use static or dynamic FI RT + execlist.extend(["-lllfi-rt"]) + execlist.extend(liblist) + retcode = execCompilation(execlist) + if retcode != 0: + print( + "...Error compiling with " + + os.path.basename(llvmgcc) + + ", trying with " + + os.path.basename(llvmgxx) + + "." + ) + execlist[0] = llvmgxx + retcode = execCompilation(execlist) + if retcode == 0: + execlist = [ + llvmgcc, + "-o", + fifile + ".exe", + fifile + ".o", + "-L" + llfilinklib, + ] + + # Check whether we should use static or dynamic FI RT + if options["useMLSpecificRT"]: + execlist.extend(["-lml-lltfi-rt"]) + else: + execlist.extend(["-lllfi-rt"]) + + execlist.extend(liblist) + retcode = execCompilation(execlist) + if retcode != 0: + print( + "...Error compiling with " + + os.path.basename(llvmgcc) + + ", trying " + + os.path.basename(llvmgxx) + + "." + ) + execlist[0] = llvmgxx + retcode = execCompilation(execlist) + + for tmpfile in tmpfiles: + try: + os.remove(tmpfile) + except Exception: + pass + if retcode != 0: + print( + "\nERROR: there was an error during linking and generating executables," + "Please take %s and %s and generate the executables manually (linking llfi-rt " + "in directory %s)." + % (proffile + _suffixOfIR(), fifile + _suffixOfIR(), llfilinklib), + file=sys.stderr, + ) + sys.exit(retcode) + else: + print("\nSuccess", file=sys.stderr) ################################################################################ def main(args): - parseArgs(args) - checkInputYaml() - readCompileOption() - compileProg() + parseArgs(args) + checkInputYaml() + readCompileOption() + compileProg() + ################################################################################ -if __name__=="__main__": - main(sys.argv[1:]) +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/bin/llfi-gui.py b/bin/llfi-gui.py index f69bbd74..f9934848 100755 --- a/bin/llfi-gui.py +++ b/bin/llfi-gui.py @@ -18,44 +18,50 @@ prog = os.path.basename(sys.argv[0]) script_path = os.path.dirname(os.path.realpath(__file__)) -sys.path.append((os.path.join(script_path, '../config'))) +sys.path.append((os.path.join(script_path, "../config"))) import java_paths -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def parseArgs(args): - global options - argid = 0 - while argid < len(args): - arg = args[argid] - if arg.startswith("-"): - if arg == "--help" or arg == "-h": - usage() + argid = 0 + while argid < len(args): + arg = args[argid] + if arg.startswith("-"): + if arg == "--help" or arg == "-h": + usage() + def startGUI(): - lib_path = os.path.join(script_path, os.pardir, 'gui/application/lib/*') - class_path = os.path.join(script_path, os.pardir, 'gui') - execlist = [java_paths.JAVA_EXECUTABLE, '-classpath', - java_paths.CMAKE_JAVA_INCLUDE_PATH+':'+lib_path+':'+class_path, - 'application.Main'] - print(' '.join(execlist)) - p = subprocess.Popen(execlist) - return + lib_path = os.path.join(script_path, os.pardir, "gui/application/lib/*") + class_path = os.path.join(script_path, os.pardir, "gui") + execlist = [ + java_paths.JAVA_EXECUTABLE, + "-classpath", + java_paths.CMAKE_JAVA_INCLUDE_PATH + ":" + lib_path + ":" + class_path, + "application.Main", + ] + print(" ".join(execlist)) + subprocess.Popen(execlist) + return + ################################################################################ def main(args): - parseArgs(args) - startGUI() + parseArgs(args) + startGUI() + ################################################################################ -if __name__=="__main__": - main(sys.argv[1:]) - +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/bin/profile.py b/bin/profile.py index bc86fd26..67021e8a 100755 --- a/bin/profile.py +++ b/bin/profile.py @@ -31,167 +31,174 @@ basedir = "" profiling_exe = "" -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) -def parseArgs(args): - global optionlist, profiling_exe, env - profiling_exe = os.path.realpath(args[0]) - basedir = os.path.abspath(os.path.dirname(os.path.dirname(profiling_exe))) - optionlist = args[1:] - - # remove the directory prefix for input files, this is to make it easier for the program - # to take a snapshot - for index, opt in enumerate(optionlist): - if os.path.isfile(opt): - if os.path.realpath(os.path.dirname(opt)) != basedir: - usage("File %s passed through option is not under current directory" % opt) - else: - optionlist[index] = os.path.basename(opt) - if basedir != os.getcwd(): - print("Change directory to:", basedir) - os.chdir(basedir) +def parseArgs(args): + global optionlist, profiling_exe + profiling_exe = os.path.realpath(args[0]) + basedir = os.path.abspath(os.path.dirname(os.path.dirname(profiling_exe))) + optionlist = args[1:] + + # remove the directory prefix for input files, this is to make it easier for the program + # to take a snapshot + for index, opt in enumerate(optionlist): + if os.path.isfile(opt): + if os.path.realpath(os.path.dirname(opt)) != basedir: + usage( + "File %s passed through option is not under current directory" % opt + ) + else: + optionlist[index] = os.path.basename(opt) + + if basedir != os.getcwd(): + print("Change directory to:", basedir) + os.chdir(basedir) def checkInputYaml(): - #Check for input.yaml's presence - yamldir = os.path.dirname(os.path.dirname(profiling_exe)) - try: - f = open(os.path.join(yamldir, 'input.yaml'), 'r') - except: - usage("No input.yaml file in the parent directory of profiling executable") - exit(1) - - #Check for input.yaml's correct formmating - try: - doc = yaml.safe_load(f) - f.close() - except: - usage("input.yaml is not formatted in proper YAML (reminder: use spaces, not tabs)") - exit(1) + # Check for input.yaml's presence + yamldir = os.path.dirname(os.path.dirname(profiling_exe)) + try: + with open(os.path.join(yamldir, "input.yaml"), "r") as f: + yaml.safe_load(f) + except OSError: + usage("No input.yaml file in the parent directory of profiling executable") + sys.exit(1) + except Exception: + usage( + "input.yaml is not formatted in proper YAML (reminder: use spaces, not tabs)" + ) + sys.exit(1) ################################################################################ def config(): - global inputdir, outputdir, baselinedir, errordir - # config - llfi_dir = os.path.dirname(profiling_exe) - - inputdir = os.path.join(llfi_dir, "prog_input") - outputdir = os.path.join(llfi_dir, "prog_output") - baselinedir = os.path.join(llfi_dir, "baseline") - errordir = os.path.join(llfi_dir, "error_output") - - if not os.path.isdir(outputdir): - os.mkdir(outputdir) - if not os.path.isdir(baselinedir): - os.mkdir(baselinedir) - if not os.path.isdir(errordir): - os.mkdir(errordir) - if not os.path.isdir(inputdir): - os.mkdir(inputdir) + global inputdir, outputdir, baselinedir, errordir + # config + llfi_dir = os.path.dirname(profiling_exe) + + inputdir = os.path.join(llfi_dir, "prog_input") + outputdir = os.path.join(llfi_dir, "prog_output") + baselinedir = os.path.join(llfi_dir, "baseline") + errordir = os.path.join(llfi_dir, "error_output") + + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + if not os.path.isdir(baselinedir): + os.mkdir(baselinedir) + if not os.path.isdir(errordir): + os.mkdir(errordir) + if not os.path.isdir(inputdir): + os.mkdir(inputdir) + ################################################################################ def execute(execlist): - #print "Begin" - #inputFile = open(inputfile, "r") - global outputfile - print('\t' + ' '.join(execlist)) - #get state of directory - dirSnapshot() - p = subprocess.Popen(execlist, stdout=subprocess.PIPE) - start_time = time.time() - - #communicate() blocks until program exit, but does not fill stdout and cause - (p_stdout,p_stderr) = p.communicate() #the target program to block - elapsetime = int(time.time() - start_time + 1) # round up to the nearest sec. - - #Continue with normal execution - moveOutput() - print("\t program finish", p.returncode) - print("\t time taken", elapsetime,"\n") - outputFile = open(outputfile, "wb") - outputFile.write(p_stdout) - outputFile.close() - replenishInput() #for cases where program deletes input or alters them each run - #inputFile.close() - return p.returncode + print("\t" + " ".join(execlist)) + # get state of directory + dirSnapshot() + p = subprocess.Popen(execlist, stdout=subprocess.PIPE) + start_time = time.time() + + # communicate() blocks until program exit, but does not fill stdout and cause + p_stdout, p_stderr = p.communicate() # the target program to block + elapsetime = int(time.time() - start_time + 1) # round up to the nearest sec. + + # Continue with normal execution + moveOutput() + print("\t program finish", p.returncode) + print("\t time taken", elapsetime, "\n") + with open(outputfile, "wb") as outputFile: + outputFile.write(p_stdout) + replenishInput() # for cases where program deletes input or alters them each run + # inputFile.close() + return p.returncode + ################################################################################ def storeInputFiles(): - global inputList - inputList=[] - ##========Consider comma as separator of arguments ================================== - temp_optionlist = [] - for item in optionlist: - if item.count(',') == 0: - temp_optionlist.append(item) - else: - temp_optionlist.extend(item.split(',')) - ##=================================================================================== - for opt in temp_optionlist: - if os.path.isfile(opt):#stores all files in inputList and copy over to inputdir - shutil.copy2(opt, os.path.join(inputdir, opt)) - inputList.append(opt) + global inputList + inputList = [] + ##========Consider comma as separator of arguments ================================== + temp_optionlist = [] + for item in optionlist: + if item.count(",") == 0: + temp_optionlist.append(item) + else: + temp_optionlist.extend(item.split(",")) + ##=================================================================================== + for opt in temp_optionlist: + if os.path.isfile( + opt + ): # stores all files in inputList and copy over to inputdir + shutil.copy2(opt, os.path.join(inputdir, opt)) + inputList.append(opt) + ################################################################################ -def replenishInput():#TODO make condition to skip this if input is present - for each in inputList: - if not os.path.isfile(each):#copy deleted inputfiles back to basedir - shutil.copy2(os.path.join(inputdir, each), each) +def replenishInput(): # TODO make condition to skip this if input is present + for each in inputList: + if not os.path.isfile(each): # copy deleted inputfiles back to basedir + shutil.copy2(os.path.join(inputdir, each), each) + ################################################################################ def moveOutput(): - #move all newly created files that are not "llfi.stat.prof.txt" < -- since this is a product of profiling - newfiles = [_file for _file in os.listdir(".")] - for each in newfiles: - if each not in dirBefore and each != "llfi.stat.prof.txt": - fileSize = os.stat(each).st_size - if fileSize == 0 and each.startswith("llfi"): - #empty library output, can delete - print(each + " is going to be deleted for having size of " + str(fileSize)) - os.remove(each) - else: - flds = each.split(".") - newName = '.'.join(flds[0:-1]) - newName+='.prof.'+flds[-1] - os.rename(each, os.path.join(baselinedir, newName)) + # move all newly created files that are not "llfi.stat.prof.txt" < -- since this is a product of profiling + newfiles = [_file for _file in os.listdir(".")] + for each in newfiles: + if each not in dirBefore and each != "llfi.stat.prof.txt": + fileSize = os.stat(each).st_size + if fileSize == 0 and each.startswith("llfi"): + # empty library output, can delete + print( + each + " is going to be deleted for having size of " + str(fileSize) + ) + os.remove(each) + else: + flds = each.split(".") + newName = ".".join(flds[0:-1]) + newName += ".prof." + flds[-1] + os.rename(each, os.path.join(baselinedir, newName)) + ################################################################################ def dirSnapshot(): - #snapshot of directory before each execute() is performed - global dirBefore - dirBefore = [_file for _file in os.listdir(".")] + # snapshot of directory before each execute() is performed + global dirBefore + dirBefore = [_file for _file in os.listdir(".")] + ################################################################################ def main(args): - global optionlist, outputfile + global outputfile - parseArgs(args) - checkInputYaml() - config() + parseArgs(args) + checkInputYaml() + config() - storeInputFiles() - # baseline - outputfile = os.path.join(baselinedir, "golden_std_output") - execlist = [profiling_exe] - execlist.extend(optionlist) + storeInputFiles() + # baseline + outputfile = os.path.join(baselinedir, "golden_std_output") + execlist = [profiling_exe] + execlist.extend(optionlist) - return execute(execlist) + return execute(execlist) ################################################################################ -if __name__=="__main__": - if len(sys.argv) == 1: - usage('Must provide the profiling executable and its options') - exit(1) - exit(main(sys.argv[1:])) +if __name__ == "__main__": + if len(sys.argv) == 1: + usage("Must provide the profiling executable and its options") + sys.exit(1) + sys.exit(main(sys.argv[1:])) diff --git a/caveats.txt b/caveats.txt index f5bbcdf5..86c837df 100755 --- a/caveats.txt +++ b/caveats.txt @@ -1,17 +1,36 @@ -GOTCHAS: - 1. For injecting fault n times, use the same IR file (i.e. run the script once, while calling the fault injector with the same IR file n times). - - 2. The classification of the injection results depends on the comparison of fault-free execution and fault-injected execution. - That means non-deterministic programs may not work well in classification. - - 3. For different test benches, the method used to classify the results of faulty executions might be different. Please write your own specific classification code. - - -KNOWN PROBLEMS: - 1. On 32 bit systems, llvm-gcc 4.2.1 might not be compatible with GCC other than version 4.4.5. - Runing llvm-gcc 4.2.1 on Ubuntu 12.04 with GCC-4.6 has failed on our test computers. - 3. On Default 64bit installations of Debian "Wheezy", running LLFI will fail on our test computers. - -Recommended Environment: - Debian 6.0.7 "Squeeze" 64bit Default Installation - +GOTCHAS: + 1. For injecting fault n times, use the same IR file (i.e. run the script once, + while calling the fault injector with the same IR file n times). + + 2. The classification of the injection results depends on the comparison of + fault-free execution and fault-injected execution. Non-deterministic programs + may not work well in classification. + + 3. For different test benches, the method used to classify the results of faulty + executions might be different. Write your own specific classification code + where necessary. + + 4. Software fault modes that target memmove/memcpy (e.g. WrongDestination(Data)) + do NOT work when the compiler optimises those calls to LLVM intrinsics + (e.g. @llvm.memmove.p0.p0.i64). The intrinsic has no injectable register + arguments at runtime. Use a fault mode targeting a regular C library call + instead (e.g. WrongPointer(Data) targets fread/fwrite). + +KNOWN PROBLEMS: + 1. ML-related tools (compiletoIR.py for ML models) + require additional dependencies (TensorFlow, PyTorch, ONNX-MLIR) that are + not part of the base installation. See README.md for installation instructions. + + 2. The --all_ml test flag runs ML/ONNX tests that require optional Python + packages (tensorflow, tf2onnx, torch, onnx, pygraphviz, pydot) and the + onnx-mlir compiler. Tests skip gracefully when dependencies are absent. + The ONNX-to-IR and fault injection tiers additionally require model.ll to + be pre-built by running compile.sh in sample_programs/ml_sample_programs/ + vision_models/mnist/. + + 3. On Ubuntu systems where LLVM is installed via apt, clang is available as + clang-20 (not clang) and will not be found automatically by ./setup. Pass + -LLVM_GXX_BIN_DIR /usr/lib/llvm-20/bin explicitly in that case. + +Recommended Environment: + 64-bit Linux (Ubuntu 20.04 or later), LLVM 20.x diff --git a/docker/Dockerfile b/docker/Dockerfile index 0cee6419..d6007291 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,6 +6,14 @@ ARG NPROC=2 RUN apt-get update RUN apt-get -y upgrade +# Install prerequisites +RUN apt-get update && apt-get install -y wget + +# Install specific version (3.20) of CMake +RUN wget https://github.com/Kitware/CMake/releases/download/v3.26.0/cmake-3.26.0-linux-x86_64.sh \ + && sh cmake-3.26.0-linux-x86_64.sh --prefix=/usr/local --exclude-subdir \ + && rm cmake-3.26.0-linux-x86_64.sh + # Install tools needed RUN distro=$(cat /etc/os-release|grep -Po '(?<=^ID=").*(?=")|(?<=^ID=)[^"].*[^"]') \ && TZ="America/Vancouver" \ @@ -16,7 +24,7 @@ RUN distro=$(cat /etc/os-release|grep -Po '(?<=^ID=").*(?=")|(?<=^ID=)[^"].*[^"] ln -sf /usr/share/zoneinfo/${TZ} /etc/localtime && \ dpkg-reconfigure -f noninteractive tzdata && \ apt-get install -qq -y --no-install-recommends \ - autoconf automake ca-certificates cmake curl \ + autoconf automake ca-certificates curl \ default-jdk-headless gcc g++ git libncurses-dev \ libtool make maven ninja-build openjdk-11-jdk-headless \ python3 python3-dev python3-distutils python3-numpy \ @@ -31,7 +39,7 @@ RUN distro=$(cat /etc/os-release|grep -Po '(?<=^ID=").*(?=")|(?<=^ID=)[^"].*[^"] https://dl.fedoraproject.org/pub/epel/epel-release-latest-8.noarch.rpm && \ yum update -q -y && \ yum install -q -y \ - autoconf automake ca-certificates cmake diffutils \ + autoconf automake ca-certificates diffutils \ file java-11-openjdk-devel java-11-openjdk-headless \ gcc gcc-c++ git libtool make ncurses-devel \ python39 python39-devel python39-numpy python39-pip \ @@ -51,13 +59,13 @@ RUN distro=$(cat /etc/os-release|grep -Po '(?<=^ID=").*(?=")|(?<=^ID=)[^"].*[^"] RUN pip install tensorflow RUN pip install tf2onnx -RUN pip3 install pyyaml===5.4.1 +RUN pip3 install pyyaml==5.4.1 WORKDIR /home/ ### LLVM RUN git clone https://github.com/llvm/llvm-project.git && \ - cd llvm-project && git checkout 9778ec057cf4 && cd .. && \ + cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd .. && \ mkdir llvm-project/build && cd llvm-project/build && \ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_PROJECTS="clang;mlir" \ @@ -71,7 +79,6 @@ RUN git clone https://github.com/llvm/llvm-project.git && \ RUN apt-get update RUN apt-get install unzip -RUN apt-get install -y wget ### libprotoc RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.17.2/protobuf-all-3.17.2.zip @@ -85,8 +92,6 @@ RUN make install RUN ldconfig WORKDIR /home/ -RUN git clone https://github.com/DependableSystemsLab/LLTFI.git - ### ONNX_MLIR RUN git clone --recursive https://github.com/DependableSystemsLab/onnx-mlir-lltfi.git && \ @@ -94,6 +99,7 @@ RUN git clone --recursive https://github.com/DependableSystemsLab/onnx-mlir-lltf mv onnx-mlir-lltfi onnx-mlir && \ cd onnx-mlir && \ git checkout LLTFI && \ + git submodule update --init --recursive && \ cd .. && \ mkdir onnx-mlir/build && cd onnx-mlir/build && \ cmake -G Ninja \ @@ -103,6 +109,7 @@ RUN git clone --recursive https://github.com/DependableSystemsLab/onnx-mlir-lltf cmake --build . -j${NPROC} && \ ninja install +RUN git clone https://github.com/DependableSystemsLab/LLTFI.git WORKDIR /home/LLTFI ### LLTFI diff --git a/docker/README.md b/docker/README.md index 60e5f5f8..207ff9e7 100755 --- a/docker/README.md +++ b/docker/README.md @@ -2,7 +2,7 @@ `docker/Dockerfile` can be used to build and run LLTFI in a docker container. - Use the variable `NPROC` to specify the number of threads that should run simultaneously. -- All dependencies required to build LLTFI and run sample programs and tests will be installed in the docker container built using this docker file. However, it does not install dependencies required to run the Web-App GUI. +- All dependencies required to build LLTFI and run sample programs and tests will be installed in the docker container built using this docker file. Steps to build: 1. **Creating a docker image from the Dockerfile:** Copy the Dockerfile to a directory of your choice outside this repository. To create an image, run the command `docker build --tag imageName .` in the terminal. diff --git a/docs/adding_a_test.md b/docs/adding_a_test.md new file mode 100644 index 00000000..ba0bf4e8 --- /dev/null +++ b/docs/adding_a_test.md @@ -0,0 +1,303 @@ +# Adding a Test Case to the LLTFI Test Suite + +This document explains the structure of the regression test suite and how to +add a new test. Read it when CONTRIBUTING.md says "new functionality has a +corresponding test case" and you are not sure where to start. + +--- + +## Overview of the test harness + +The test suite lives in `test_suite/` (source tree) and is copied into +`/test_suite/` during installation. The driver script is: + +``` +test_suite/SCRIPTS/llfi_test +``` + +Tests are split into categories invoked with flags such as `--all_hardware_faults`, +`--all_trace_tools_tests`, etc. The category for most new fault injection tests +is **HardwareFaults**. + +--- + +## Directory layout + +``` +test_suite/ + PROGRAMS/ Compiled IR files (.ll / .bc) and input data + factorial/ One sub-directory per program + factorial.c + Makefile + mcf/ + ... + Makefile Builds all PROGRAMS/* in one pass + Makefile.common Shared clang flags, included by every program Makefile + + HardwareFaults/ One sub-directory per hardware-fault test case + random/ + input.yaml Controls instrumentation and injection for this test + (MCF.ll, inp.in) Deployed here at test time from PROGRAMS/ + funcname/ + ... + + test_suite.yaml Master registry — lists programs, their files, and + which test case uses which program + + SCRIPTS/ + llfi_test Top-level driver + build_prog.py Builds programs in PROGRAMS/ + deploy_prog.py Copies program files into test case directories + inject_prog.py Runs instrument → profile → injectfault for each case + check_injection.py Verifies the llfi/ output directory is well-formed + test_trace_tools.py Trace-tool specific tests + test_instruction_duplication.py SID pass tests + test_ml_models.py ML fault injection tests + test_ml_tools.py ML tool unit tests +``` + +--- + +## How a fault injection test runs + +When `llfi_test --all_hardware_faults` runs, it calls these scripts in order: + +1. **`build_prog.py`** — builds all programs listed in `test_suite.yaml` under + `PROGRAMS:` by running `make` in each program's directory. + +2. **`deploy_prog.py`** — copies the compiled files (e.g. `MCF.ll`, `inp.in`) + from `PROGRAMS/mcf/` into each test case directory (e.g. `HardwareFaults/random/`) + that uses that program. + +3. **`inject_prog.py`** — for each test case directory: + - `cd` into the directory + - runs `instrument` on the `.ll` file + - runs `profile` + - runs `injectfault` + - logs stdout/stderr to `llfi.test.log.{instrument,profile,injectFault}.txt` + +4. **`check_injection.py`** — verifies that each test case directory now contains + a well-formed `llfi/` tree (subdirectories present, at least one stat file). + Reports PASS or FAIL. + +--- + +## Adding a HardwareFaults test + +### Step 1 — Decide which program to use + +Check whether an existing program in `PROGRAMS/` suits the new test. Reusing +a program is strongly preferred. If you need a new program, follow Step 1a. + +**Step 1a (if a new program is needed) — Add a program** + +Create a subdirectory under `test_suite/PROGRAMS/`: + +``` +test_suite/PROGRAMS/myprogram/ + myprogram.c (or myprogram.ll if you hand-write IR) + Makefile + myinput.txt (input data files, if needed) +``` + +The `Makefile` must produce a `.ll` file. Use an existing program's Makefile +as a template: + +```makefile +TARGET = myprogram +include ../Makefile.common + +SRC_FILES = $(wildcard *.c) +OBJECTS = $(SRC_FILES:.c=.bc) +LL_FILE = $(TARGET).ll + +default: all +all: $(LL_FILE) + +%.ll: %.bc + $(LLVMDIS) $< -o $@ + +%.bc: %.c + $(LLVMGCC) $(COMPILE_FLAGS) $< -c -o $@ + +clean: + $(RM) -f *.bc *.ll +``` + +Then register the program and its files in `test_suite.yaml`: + +```yaml +PROGRAMS: + myprogram: + - myprogram.ll + - myinput.txt # list every file that needs to be deployed +``` + +And add the program input under `INPUTS:`: + +```yaml +INPUTS: + myprogram: myinput.txt # passed as argv to the program; omit if none +``` + +### Step 2 — Create the test case directory + +```bash +mkdir test_suite/HardwareFaults/my_new_test +``` + +Write an `input.yaml` in that directory. Example for a hardware fault test: + +```yaml +defaultTimeOut: 500 + +compileOption: + instSelMethod: + - insttype: + include: + - all + exclude: + - ret + + regSelMethod: regloc + regloc: dstreg + +runOption: + - run: + numOfRuns: 5 + fi_type: bitflip +``` + +See `docs/input_yaml_guide.md` for all available keys. + +### Step 3 — Register the test case in `test_suite.yaml` + +Add an entry under `HardwareFaults:`: + +```yaml +HardwareFaults: + my_new_test: myprogram # value is the PROGRAMS key +``` + +The value tells `deploy_prog.py` which program's files to copy into the test +case directory before running injection. + +### Step 4 — Run the test locally + +From the **build** directory: + +```bash +cd /test_suite + +# Rebuild programs (only needed once, or after source changes) +python3 SCRIPTS/build_prog.py + +# Run only the new test +python3 SCRIPTS/llfi_test --test_cases HardwareFaults/my_new_test +``` + +A PASS result means the harness found a well-formed `llfi/` directory after +injection. Check that the stat files are non-empty and that the output makes +sense for your test case. + +### Step 5 — Verify it does not break the full suite + +```bash +python3 SCRIPTS/llfi_test --all_cpp +``` + +Expected: 21/21 (or 22/22 if your test adds to the count) PASS. + +--- + +## Adding a more specific test (custom Python script) + +For tests that check something beyond "did injection produce a well-formed +`llfi/` directory" — e.g., counting the number of injected faults, verifying +a specific stat value, or testing a tool that is not part of the injection +pipeline — write a standalone Python test script in `test_suite/SCRIPTS/`. + +Look at `test_trace_tools.py` for examples of +the pattern: + +```python +#!/usr/bin/env python3 +""" +One-line description. + +Usage: python3 SCRIPTS/test_my_feature.py +""" + +import os +import sys + +PASS = "PASS" +FAIL = "FAIL" +SKIP = "SKIP" + +def test_something(): + # ... perform the check ... + if condition: + return PASS, "explanation" + else: + return FAIL, "what went wrong" + +def main(): + results = [] + results.append(("my_check_name", test_something())) + + passed = failed = skipped = 0 + for name, (status, msg) in results: + print(f"{name}: {status}" + (f" — {msg}" if msg else "")) + if status == PASS: + passed += 1 + elif status == FAIL: + failed += 1 + else: + skipped += 1 + + print(f"\n{passed} PASS, {failed} FAIL, {skipped} SKIP") + sys.exit(0 if failed == 0 else 1) + +if __name__ == "__main__": + main() +``` + +### SKIP convention + +Return `SKIP` (not `FAIL`) when a test requires an optional dependency that +is not installed. This keeps the suite green on machines that do not have +TensorFlow, onnx-mlir, etc. Always print a short message explaining what is +missing: + +```python +try: + import onnx +except ImportError: + return SKIP, "onnx not installed — pip install onnx" +``` + +### Registering a custom script in `llfi_test` + +If the script should run as part of `--all` or a specific `--all_*` flag, add +a call to it inside `llfi_test`. Search for an existing call (e.g. to +`test_trace_tools.py`) to see the pattern: + +```python +# In llfi_test, inside the relevant branch: +r = subprocess.call([sys.executable, + os.path.join(script_dir, "test_my_feature.py")]) +``` + +If it is an optional test (ML, ONNX), add it under the `--all_ml` branch. + +--- + +## Checklist before submitting + +- [ ] New program (if any) has a `Makefile` that produces a `.ll` file +- [ ] Program and its input files are registered in `test_suite.yaml` +- [ ] Test case directory exists with a valid `input.yaml` +- [ ] Test case is registered under the correct category in `test_suite.yaml` +- [ ] `python3 SCRIPTS/llfi_test --test_cases ` reports PASS +- [ ] `python3 SCRIPTS/llfi_test --all` still reports 21/21 (or N/N) PASS diff --git a/docs/input_masterlist.yaml b/docs/input_masterlist.yaml index ca4a7e6f..9f945ea0 100644 --- a/docs/input_masterlist.yaml +++ b/docs/input_masterlist.yaml @@ -40,15 +40,6 @@ compileOption: - dual_feasible - main -## To use a specific instruction selector to select targeting instructions: -## ('BufferOverflow(API)' is the name of a software failure instruction -## selector shipped with LLFI) -compileOption: - instSelMethod: - - customInstselector: - include: - - BufferOverflow(API) - ## To use a location-based register selector to select targeting register: regSelMethod: regloc ## To select the destination register @@ -60,14 +51,6 @@ compileOption: ## To select the 3rd source register regloc: srcreg3 - - ## To use a custom register selector to select targeting register: - ## ('BufferOverflow(API)' is the name of a software failure register - ## selector shipped with LLFI) - regSelMethod: customregselector - customRegSelector: BufferOverflow(API) - - ## To incorporate the forward slice/backward slice of the selected instructions as injection targets: includeInjectionTrace: - forward # include forward trace of the selected instructions into fault injection targets @@ -115,13 +98,6 @@ runOption: fi_type: bitflip window_len: 10 - ## To use a custom fault injector (fault type) for this experiment: - ## ('BufferOverflow(API)' is an fault injector for software failures - ## shipped with LLFI) - - run: - numOfRuns: 5 - fi_type: BufferOverflow(API) - ## To inject a maximum of 4 bit-flip errors into multiple registers. ## The distance between each consecutive injection is controlled by a random number representing the number of dynamic instructions that needs ## to be executed between the consecutive injections. The lower bound and upper bound of the random number are especified as two of the run time options. diff --git a/docs/input_masterlist_ml.yaml b/docs/input_masterlist_ml.yaml index 24a40bfb..b06f2efb 100644 --- a/docs/input_masterlist_ml.yaml +++ b/docs/input_masterlist_ml.yaml @@ -103,13 +103,6 @@ runOption: fi_type: bitflip window_len: 10 - ## To use a custom fault injector (fault type) for this experiment: - ## ('BufferOverflow(API)' is an fault injector for software failures - ## shipped with LLFI) - - run: - numOfRuns: 5 - fi_type: BufferOverflow(API) - ## To inject a maximum of 4 bit-flip errors into multiple registers. ## The distance between each consecutive injection is controlled by a random number representing the number of dynamic instructions that needs ## to be executed between the consecutive injections. The lower bound and upper bound of the random number are especified as two of the run time options. diff --git a/docs/input_yaml_guide.md b/docs/input_yaml_guide.md new file mode 100644 index 00000000..1a7cf99c --- /dev/null +++ b/docs/input_yaml_guide.md @@ -0,0 +1,494 @@ +# LLTFI `input.yaml` Guide + +Every program you run through LLTFI needs an `input.yaml` file in its working +directory. This file controls which instructions are targeted for fault +injection (compile time) and how many runs are performed with which fault type +(run time). + +The reference schemas for all keys are in: + +- `docs/input_masterlist.yaml` — C/C++ programs +- `docs/input_masterlist_ml.yaml` — ML programs (ONNX-MLIR compiled) + +This guide explains the keys in prose and provides annotated examples. +For a complete end-to-end walkthrough see `docs/tutorial_first_experiment.md` +(C/C++ programs) or `docs/tutorial_ml_experiment.md` (ML/ONNX models). + + +## Top-level structure + +```yaml +defaultTimeOut: # optional; default 500 s + +kernelOption: # optional + - forceRun + +compileOption: # required + instSelMethod: ... + regSelMethod: ... + # optional tracing settings + +runOption: # required; list of one or more run blocks + - run: + ... +``` + +--- + +## `defaultTimeOut` + +Wall-clock timeout (seconds) applied to every fault injection run. Individual +`run` blocks may override this with their own `timeOut` key. + +```yaml +defaultTimeOut: 1000 +``` + +--- + +## `kernelOption` + +Optional. Currently only one value is meaningful: + +| Value | Effect | +|-------|--------| +| `forceRun` | Run the program even if profiling detects zero injectable instructions. Useful when the target kernel is very short. | + +```yaml +kernelOption: + - forceRun +``` + +--- + +## `compileOption` + +Controls how `instrument.py` selects instructions and registers to mark as +fault injection targets. + +### `instSelMethod` + +A list with **exactly one** selector entry. Three selector kinds are +available. + +#### `insttype` — select by LLVM IR instruction type + +```yaml +instSelMethod: + - insttype: + include: + - all # special keyword: every instruction type + exclude: + - ret # always recommended — injecting into ret corrupts the stack + - alloca + - call +``` + +`include` and `exclude` take LLVM IR instruction names (lower-case mnemonic as +they appear in `.ll` files). Common values: + +| Name | Instruction | +|------|-------------| +| `add`, `sub`, `mul` | Integer arithmetic | +| `fadd`, `fsub`, `fmul`, `fdiv` | Floating-point arithmetic | +| `load`, `store` | Memory access | +| `getelementptr` | Address calculation | +| `icmp`, `fcmp` | Comparison | +| `call` | Function call | +| `ret` | Return (avoid injecting here) | +| `phi` | PHI node | +| `alloca` | Stack allocation (avoid) | + +The full list is the LLVM Language Reference: https://llvm.org/docs/LangRef.html + +#### `funcname` — select by containing function name + +Only instruments instructions that are (or are not) inside a named function. +Combine with `insttype` semantics: `all` in `include` means every function. + +```yaml +instSelMethod: + - funcname: + include: + - all + exclude: + - main + - helper_init +``` + +#### `customInstselector` — use a named selector plugin + +Used for ML layer targeting. The `include` list names +the selector class; `options` passes arguments to it. + +For ML programs (see [ML selectors](#ml-programs-customtensoroperator) below): + +```yaml +instSelMethod: + - customInstselector: + include: + - CustomTensorOperator + options: + - -layerNo=0 + - -layerName=conv +``` + +For pinning to specific LLFI indices (useful for reproducing a previous result): + +```yaml +instSelMethod: + - customInstselector: + include: + - llfiindex + options: + - -injecttoindex=2293 + - -injecttoindex=568 +``` + +--- + +### `regSelMethod` + +Selects which register within each targeted instruction to corrupt. + +#### `regloc` — location-based selection + +```yaml +regSelMethod: regloc +regloc: dstreg # destination register (output of the instruction) +``` + +| `regloc` value | Meaning | +|----------------|---------| +| `dstreg` | Destination (output) register | +| `srcreg1` | First source register | +| `srcreg2` | Second source register | +| `srcreg3` | Third source register | +| `allreg` | All registers of the instruction | + +`dstreg` is the most common choice for hardware fault experiments. `allreg` +increases the injection surface. + +--- + +### `includeInjectionTrace` (optional) + +Expands the injection target set to include data-flow dependents of the +originally selected instructions. + +```yaml +includeInjectionTrace: + - forward # instructions that consume the selected instruction's output + - backward # instructions whose output feeds into the selected instruction +``` + +Including the trace increases fault coverage at the cost of a larger injection +set (more possible injection points, longer profiling). + +--- + +### `tracingPropagation` (optional) + +Enables dynamic value tracing during fault injection runs, writing trace files +to `llfi/llfi_stat_output/`. + +```yaml +tracingPropagation: True + +tracingPropagationOption: + maxTrace: 250 # max instructions recorded per run + debugTrace: False # print trace to stderr during run + generateCDFG: True # write a dot-format control/data-flow graph + mlTrace: False # use ML-aware trace format (ML programs only) +``` + +Tracing adds overhead. Disable it for large experiments where you only need +the pass/fail outcome. + +--- + +## `runOption` + +A list of one or more `run` blocks. Each block is an independent experiment +that runs after the single profiling pass. All blocks use the same +instrumented binary. + +### Common keys + +| Key | Type | Meaning | +|-----|------|---------| +| `numOfRuns` | int | Number of fault injection trials | +| `fi_type` | string | Fault type (see below) | +| `timeOut` | int | Per-run timeout in seconds; overrides `defaultTimeOut` | +| `verbose` | bool | Print return-code summary after each run | + +### `fi_type` values + +**Hardware faults:** + +| Value | Effect | +|-------|--------| +| `bitflip` | Flip a randomly chosen bit in the register | +| `stuck_at_0` | Force all bits to 0 | +| `stuck_at_1` | Force all bits to 1 | + +**Auto-injection** — let the runtime choose the injector: + +```yaml +fi_type: AutoInjection +``` + +--- + +### Pinning a fault (reproducibility) + +To reproduce a specific previous injection, pin the exact cycle, register, and +bit. All four keys must be present together: + +```yaml +- run: + numOfRuns: 1 + fi_type: bitflip + fi_cycle: 684347 # dynamic instruction count at injection point + fi_index: 417 # LLFI index of the targeted instruction + fi_reg_index: 0 # which register of that instruction (0-based) + fi_bit: 15 # which bit to flip (0-based from LSB) +``` + +Additional pin keys (less commonly needed): + +| Key | Meaning | +|-----|---------| +| `fi_reg` | Raw register identifier (internal) | +| `fi_reg_pos` | Position within a multi-word register | + +--- + +### Multiple-bit faults + +Flip more than one bit within a single register: + +```yaml +- run: + numOfRuns: 5 + fi_type: bitflip + fi_num_bits: 4 # flip 4 randomly chosen bits in one register +``` + +--- + +### Two-fault experiments (`window_len`) + +Inject into two different registers with a bounded gap between them: + +```yaml +- run: + numOfRuns: 50 + fi_type: bitflip + window_len: 10 # max dynamic instructions between the two injections +``` + +--- + +### Multiple faults across registers (`fi_max_multiple`) + +Inject up to N faults into separate registers, with a random spacing drawn from +`[window_len_multiple_startindex, window_len_multiple_endindex]` dynamic +instructions between consecutive injections: + +```yaml +- run: + numOfRuns: 5 + fi_type: bitflip + fi_max_multiple: 4 + window_len_multiple_startindex: 10 + window_len_multiple_endindex: 100 +``` + +These keys can be combined with pin keys to anchor the first injection and then +spread subsequent ones randomly. + +--- + +## ML programs: `CustomTensorOperator` + +When injecting into an ONNX-MLIR compiled model, use `CustomTensorOperator` as +the instruction selector. This selector understands the model's layer +structure. + +```yaml +instSelMethod: + - customInstselector: + include: + - CustomTensorOperator + options: + - -layerNo= + - -layerName= +``` + +### `layerNo` + +| Value | Meaning | +|-------|---------| +| `0` | All layers of the given type | +| `1` | First layer of the given type | +| `2` | Second layer, and so on | + +### `layerName` + +Valid layer type names: + +`conv`, `relu`, `maxpool`, `matmul`, `add`, `avgpool`, `loop`, +`nonmaxs`, `unsqueeze`, `softmax`, `all` + +Use `all` with `layerNo=0` to target every instruction in the model. + +### Targeting multiple layer types + +Separate multiple entries with `;`: + +```yaml +options: + - -layerNo=0;0;0 + - -layerName=conv;relu;matmul +``` + +The `layerNo` and `layerName` lists must have the same length. + +--- + +## Complete examples + +### Minimal hardware fault experiment + +```yaml +compileOption: + instSelMethod: + - insttype: + include: + - all + exclude: + - ret + + regSelMethod: regloc + regloc: dstreg + +runOption: + - run: + numOfRuns: 100 + fi_type: bitflip +``` + +### Hardware fault with tracing and multiple run configurations + +```yaml +defaultTimeOut: 1000 + +compileOption: + instSelMethod: + - insttype: + include: + - fadd + - fmul + - fdiv + exclude: + - ret + + regSelMethod: regloc + regloc: dstreg + + includeInjectionTrace: + - forward + + tracingPropagation: True + tracingPropagationOption: + maxTrace: 250 + debugTrace: False + generateCDFG: False + +runOption: + - run: + numOfRuns: 100 + fi_type: bitflip + + - run: + numOfRuns: 50 + fi_type: stuck_at_0 + + - run: + numOfRuns: 50 + fi_type: stuck_at_1 +``` + +### ML model — all convolutional layers, multiple faults per run + +```yaml +defaultTimeOut: 5000 + +compileOption: + instSelMethod: + - customInstselector: + include: + - CustomTensorOperator + options: + - -layerNo=0 + - -layerName=conv + + regSelMethod: regloc + regloc: dstreg + + includeInjectionTrace: + - forward + + tracingPropagation: False + tracingPropagationOption: + maxTrace: 250 + debugTrace: False + mlTrace: False + generateCDFG: False + +runOption: + - run: + numOfRuns: 1000 + fi_type: bitflip + fi_max_multiple: 2 + window_len_multiple_startindex: 1 + window_len_multiple_endindex: 500 +``` + +### ML model — all layer types + +```yaml +defaultTimeOut: 5000 + +compileOption: + instSelMethod: + - customInstselector: + include: + - CustomTensorOperator + options: + - -layerNo=0;0;0;0;0;0;0 + - -layerName=conv;relu;matmul;maxpool;add;avgpool;softmax + + regSelMethod: regloc + regloc: dstreg + + includeInjectionTrace: + - forward + + tracingPropagation: False + tracingPropagationOption: + maxTrace: 250 + debugTrace: False + mlTrace: False + generateCDFG: True + +runOption: + - run: + numOfRuns: 1000 + fi_type: bitflip + fi_max_multiple: 2 + window_len_multiple_startindex: 1 + window_len_multiple_endindex: 500 + timeOut: 5000 +``` diff --git a/docs/tutorial_first_experiment.md b/docs/tutorial_first_experiment.md new file mode 100644 index 00000000..40072470 --- /dev/null +++ b/docs/tutorial_first_experiment.md @@ -0,0 +1,292 @@ +# Tutorial: Your First Fault Injection Experiment + +This tutorial walks you through a complete fault injection experiment using the +`factorial` sample program. By the end you will have: + +- compiled a C program to LLVM IR +- instrumented it with LLTFI +- run the profiling pass +- run fault injection +- interpreted every output file + +**Prerequisites:** LLTFI is installed and `LLFI_BUILD_ROOT` points at the build +directory (see `README.md` for setup). + +--- + +## 1. Set up environment variables + +```bash +export LLFI_BUILD_ROOT=/path/to/LLTFI-build +export PATH=/usr/lib/llvm-20/bin:$PATH # so clang, opt, llvm-dis are found +``` + +--- + +## 2. Copy the sample program to a working directory + +```bash +cp -r $LLFI_BUILD_ROOT/../sample_programs/cpp_sample_programs/factorial /tmp/factorial +cd /tmp/factorial +``` + +The directory contains: + +``` +factorial.c Source file +compileAndRun.sh Convenience script (wraps the three steps below) +input.yaml Fault injection configuration +``` + +--- + +## 3. Compile to LLVM IR + +LLTFI works on LLVM bitcode (`.bc`) or human-readable IR (`.ll`). Compile +with `-emit-llvm`: + +```bash +clang -emit-llvm -g -S factorial.c -o factorial.ll +``` + +The `-g` flag adds debug information, which improves LLFI's index maps. + +--- + +## 4. Examine `input.yaml` + +```yaml +compileOption: + instSelMethod: + - insttype: + include: + - all + exclude: + - ret + + regSelMethod: regloc + regloc: allreg + + tracingPropagation: False + + tracingPropagationOption: + maxTrace: 250 + debugTrace: False + generateCDFG: True + +runOption: + - run: + numOfRuns: 5 + fi_type: bitflip +``` + +This says: + +- **Target**: every instruction type except `ret` (returning mid-function would + corrupt the stack, giving meaningless results) +- **Register**: all registers of each targeted instruction (`allreg`) +- **Tracing**: off (keeps output small for this tutorial) +- **Experiment**: 5 fault injection runs, each flipping a randomly chosen bit + in a randomly chosen targeted register + +For the full key reference, see `docs/input_yaml_guide.md`. + +--- + +## 5. Instrument + +```bash +$LLFI_BUILD_ROOT/bin/instrument --readable factorial.ll +``` + +This runs the LLVM pass pipeline on `factorial.ll` and produces the `llfi/` +directory: + +``` +llfi/ + factorial-profiling.exe Binary for the profiling pass + factorial-faultinjection.exe Binary for fault injection + factorial-profiling.ll Instrumented IR (profiling version) + factorial-faultinjection.ll Instrumented IR (fault injection version) + factorial-llfi_index.ll IR annotated with LLFI index numbers +``` + +Two top-level files are also written: + +| File | Contents | +|------|----------| +| `llfi.log.compilation.txt` | Full output of the instrumentation pass (errors appear here) | +| `llfi.config.compiletime.txt` | Summary of what the pass selected (failure class, mode, targets) | + +Check `llfi.config.compiletime.txt` to verify the selector matched what you +expected: + +``` +failure_class=HardwareFault +failure_mode=SpecifiedInstructionTypes +targets= +injector= +``` + +--- + +## 6. Profile + +```bash +$LLFI_BUILD_ROOT/bin/profile ./llfi/factorial-profiling.exe 6 +``` + +The argument `6` is the program input (compute 6! = 720). + +Profiling runs the program once without injecting any faults. It produces: + +| File | Contents | +|------|----------| +| `llfi.stat.prof.txt` | Total dynamic instruction count: `total_cycle=` | +| `llfi.stat.totalindex.txt` | Number of unique injectable instruction indices: `totalindex=` | +| `llfi/baseline/golden_std_output` | Stdout of the fault-free run — the reference for SDC detection | +| `llfi/baseline/mcf.prof.out` | Program-specific profiling output (if any) | + +For factorial with input 6, `totalindex` will be around 50–100 (a small +program) and `golden_std_output` will contain `720`. + +--- + +## 7. Inject faults + +```bash +$LLFI_BUILD_ROOT/bin/injectfault ./llfi/factorial-faultinjection.exe 6 +``` + +This runs the program `numOfRuns` times (5 in our `input.yaml`), injecting one +fault per run. Each run: + +1. Reads `llfi.config.runtime.txt` to decide where and what to inject. +2. Selects a random injectable instruction and a random bit. +3. Flips that bit at the moment the instruction executes. +4. Records the outcome. + +On completion, the `llfi/` directory is populated: + +``` +llfi/ + baseline/ + golden_std_output Reference output (from profiling) + std_output/ + std_outputfile-run-0-0 Stdout of run 0 (first experiment, trial 0) + std_outputfile-run-0-1 Stdout of run 1 + ... + error_output/ + errorfile-run-0-N Written only for runs that crashed or hung + llfi_stat_output/ + llfi.stat.fi.injectedfaults.0-0.txt Injection details for run 0 + llfi.stat.fi.injectedfaults.0-1.txt Injection details for run 1 + ... + prog_output/ + (disk output from the program, if any) +``` + +--- + +## 8. Interpret the results + +### 8.1 Stat files — what was injected + +Each `llfi.stat.fi.injectedfaults.-.txt` records exactly what happened +during one trial. Example: + +``` +FI stat: fi_type=bitflip, fi_max_multiple=-1, fi_index=12, fi_cycle=47, + fi_reg_index=0, fi_reg_pos=0, fi_reg_width=64, fi_bit=28, opcode=load +``` + +| Field | Meaning | +|-------|---------| +| `fi_type` | Fault type injected (`bitflip`, `stuck_at_0`, `stuck_at_1`) | +| `fi_index` | LLFI index of the targeted instruction (matches `factorial-llfi_index.ll`) | +| `fi_cycle` | Dynamic instruction count at the moment of injection | +| `fi_reg_index` | Which register of the instruction was targeted (0 = destination) | +| `fi_reg_pos` | Word position within a multi-word register (usually 0) | +| `fi_reg_width` | Register width in bits (32 or 64) | +| `fi_bit` | Which bit was flipped (0 = LSB) | +| `opcode` | LLVM IR opcode of the targeted instruction | +| `fi_max_multiple` | Number of faults injected; -1 = single fault | + +These values are sufficient to reproduce a specific run — copy them back into +`input.yaml` as pin fields (see `docs/input_yaml_guide.md` §Pinning a fault). + +### 8.2 Classifying outcomes + +Compare each run's `std_outputfile-run--` to `baseline/golden_std_output`: + +| Outcome | How to identify | +|---------|----------------| +| **Masked** | `std_outputfile` is identical to `golden_std_output`; `errorfile` absent | +| **SDC** (Silent Data Corruption) | `std_outputfile` differs from `golden_std_output`; `errorfile` absent | +| **Crash** | `errorfile` contains "Program crashed" or a signal number (e.g. -11 = SIGSEGV) | +| **Hang** | `errorfile` contains "Program timed out" or similar timeout message | + +Example `errorfile` for a crash: + +``` +Program crashed, terminated by the system, return code -11 +``` + +Example `errorfile` for a hang: + +``` +Program timed out +``` + +For factorial with input 6, most runs will be masked (the program is small and +the fault often hits an unused register) or will produce a different number +(SDC). Crashes are rare because the program does no pointer arithmetic. + +### 8.3 Quick manual comparison + +```bash +# Compare all runs against the golden output +for f in llfi/std_output/std_outputfile-run-0-*; do + echo -n "$f: " + if diff -q "$f" llfi/baseline/golden_std_output > /dev/null 2>&1; then + echo "MASKED" + else + echo "SDC or CRASH — check llfi/error_output/" + fi +done +``` + +--- + +## 9. Reproducing a specific run + +If run 0-2 produced an interesting result and you want to reproduce it exactly, +copy the stat fields from `llfi.stat.fi.injectedfaults.0-2.txt` back into +`input.yaml`: + +```yaml +runOption: + - run: + numOfRuns: 1 + fi_type: bitflip + fi_cycle: 47 + fi_index: 12 + fi_reg_index: 0 + fi_bit: 28 +``` + +Then re-run `injectfault`. The runtime reads these pin values and injects at +exactly the same point. + +--- + +## 10. Next steps + +- **Try a different selector**: change `insttype: include: [all]` to + `insttype: include: [fadd, fmul]` to target only floating-point arithmetic. +- **Try an ML model**: see `docs/tutorial_ml_experiment.md` for a complete + walkthrough of layer-targeted fault injection on a TensorFlow/ONNX model. +- **Add tracing**: set `tracingPropagation: True` and re-run to generate + per-instruction value traces in `llfi/llfi_stat_output/`. +- **Read the architecture**: `architecture.md` explains how the pass pipeline, + selectors, and runtime library fit together. diff --git a/docs/tutorial_ml_experiment.md b/docs/tutorial_ml_experiment.md new file mode 100644 index 00000000..2ae5ac7f --- /dev/null +++ b/docs/tutorial_ml_experiment.md @@ -0,0 +1,576 @@ +# Tutorial: Fault Injection for ML Models + +This tutorial walks you through a complete fault injection experiment on the +`mnist` sample — a convolutional neural network that classifies handwritten +digit images. By the end you will have: + +- converted a TensorFlow model through ONNX to LLVM IR +- instrumented the IR with LLTFI's layer-aware `CustomTensorOperator` selector +- profiled the model to obtain per-layer timing data +- injected multi-bit faults into a specific layer +- interpreted the injected-fault stat files and classified outcomes + +The ML pipeline has extra steps before instrumentation compared to the C/C++ +tutorial (`docs/tutorial_first_experiment.md`), because LLTFI works on LLVM IR +and ML frameworks do not emit it directly. + +**Prerequisites:** + +| Requirement | Notes | +|---|---| +| LLTFI built | `LLFI_BUILD_ROOT` set; see `README.md` | +| LLVM 20 tools on PATH | `clang`, `opt`, `llvm-link`, `mlir-translate` | +| onnx-mlir (LLTFI branch) | `ONNX_MLIR_SRC` and `ONNX_MLIR_BUILD` set; see `README.md` | +| json-c | `tools/json-c-setup.sh` installs it | +| Python packages | `pip install tensorflow tf2onnx onnx` | + +--- + +## 1. Set up environment variables + +```bash +export LLFI_BUILD_ROOT=/path/to/LLTFI-build +export ONNX_MLIR_SRC=/path/to/onnx-mlir +export ONNX_MLIR_BUILD=/path/to/onnx-mlir/build +export PATH=/usr/lib/llvm-20/bin:$ONNX_MLIR_BUILD/Debug/bin:$PATH +``` + +--- + +## 2. Copy the sample program to a working directory + +```bash +cp -r $LLFI_BUILD_ROOT/../sample_programs/ml_sample_programs/vision_models/mnist /tmp/mnist +cd /tmp/mnist +``` + +The directory contains: + +``` +mnist-cnn.py TensorFlow model definition (CNN trained on MNIST) +mnist-cnn-pytorch.py PyTorch alternative +image.c C driver: loads an image, runs the model, exports layer outputs to JSON +stb_image.h Single-header image loader used by image.c +model.onnx Pre-trained ONNX model (regenerate with compile.sh if needed) +eight.png five.png Test images (one per digit class provided) +nine.png seven.png +input.yaml Fault injection configuration +compile.sh Compilation pipeline (ONNX → LLVM IR) +runllfi.sh Instrumentation, profiling, and fault injection +clean.sh Removes all generated output files +``` + +--- + +## 3. The ML compilation pipeline + +Unlike a C/C++ program, an ML model requires several steps before LLTFI can +work on it. `compile.sh` automates all of them; this section explains each +step so you know what is produced and why. + +### 3.1 Step 0 (optional): train and export the model + +If `model.onnx` is already present, `compile.sh` skips this step. + +```bash +python3 mnist-cnn.py # train +python3 -m tf2onnx.convert --saved-model mnist-cnn.tf \ + --output model.onnx # export +``` + +`model.onnx` is a self-contained representation of the network graph and its +trained weights. + +### 3.2 Step 1: extend the ONNX model to expose intermediate layer outputs + +```bash +python3 $LLFI_BUILD_ROOT/../tools/ExtendONNXModel.py \ + --model_path ./model.onnx \ + --output_model_path ./extendedmodel.onnx > expected_op_seq.txt +``` + +Standard ONNX models expose only the final output. `ExtendONNXModel.py` adds +intermediate tensor outputs for every operator (conv, relu, maxpool, …) so +that `image.c` can export them to JSON during each run. This is needed later +by `CompareLayerOutputs.py` to pinpoint which layer a fault corrupted. + +`expected_op_seq.txt` receives the stdout of `ExtendONNXModel.py`: a +comma-separated list of operator position indices in execution order (e.g. +`0,1,2,3,4,5,6`). It is passed to the model binary at runtime so the driver +knows which outputs to save. + +### 3.3 Step 2: compile to MLIR with instrumentation hooks + +```bash +onnx-mlir --EmitLLVMIR extendedmodel.onnx \ + --instrument-onnx-ops="ALL" \ + --InstrumentBeforeAndAfterOp +``` + +`onnx-mlir` lowers the ONNX graph to MLIR and then to an LLVM IR dialect, +emitting `extendedmodel.onnx.mlir`. + +The two instrumentation flags are important for LLTFI: + +| Flag | Effect | +|---|---| +| `--instrument-onnx-ops="ALL"` | Inserts `@OMInstrumentPoint(operator_id, flag)` calls around every operator | +| `--InstrumentBeforeAndAfterOp` | Generates both a start call (`flag=2`) and end call (`flag=1`) for each operator boundary | + +`ProfilingPass` inserts calls to `lltfiMLLayer()` at these boundaries during +instrumentation. The runtime records the dynamic instruction cycle range +`[start, end]` for each operator, enabling LLTFI to confine fault injection to +a specific layer. + +### 3.4 Step 3: translate to LLVM IR + +```bash +mlir-translate -mlir-to-llvmir extendedmodel.onnx.mlir > model.mlir.ll +``` + +### 3.5 Step 4: compile the C driver and link everything + +```bash +clang -S -emit-llvm image.c -I$ONNX_MLIR_SRC/include -o main.ll +llvm-link -o model.ll -S main.ll model.mlir.ll +``` + +`image.c` is the program entry point: it reads the input image, calls +`run_main_graph()` (the compiled network), and writes the layer outputs to +`layeroutput.txt` in JSON format. `llvm-link` merges the driver IR and model +IR into a single `model.ll` that LLTFI instruments as one unit. + +Run all four steps at once: + +```bash +./compile.sh +``` + +--- + +## 4. Examine `input.yaml` + +```yaml +compileOption: + instSelMethod: + - customInstselector: + include: + - CustomTensorOperator + options: + - -layerNo=0;0;0;0;0;0;0 + - -layerName=conv;relu;matmul;maxpool;add;avgpool;softmax + + regSelMethod: regloc + regloc: dstreg + + includeInjectionTrace: + - forward + + tracingPropagation: False + + tracingPropagationOption: + maxTrace: 250 + debugTrace: False + mlTrace: False + generateCDFG: True + +runOption: + - run: + numOfRuns: 1000 + fi_type: bitflip + window_len_multiple_startindex: 1 + window_len_multiple_endindex: 500 + fi_max_multiple: 2 +``` + +### 4.1 The `CustomTensorOperator` selector + +`CustomTensorOperator` targets floating-point arithmetic instructions +(`fadd`, `fsub`, `fmul`, `fdiv`, `fcmp`) that fall inside the +`OMInstrumentPoint` boundary for the specified operators. It is the +recommended selector for layer-level ML fault injection. + +| `instSelMethod` | What it targets | +|---|---| +| `CustomTensorOperator` | FP arith inside named operator boundaries; needs `--instrument-onnx-ops` compilation | +| `maingraph` | All FP arith anywhere in `main_graph()`; no layer granularity | +| `insttype: include: [fadd, fmul]` | Any FP arith in the whole module; no layer or operator granularity | + +### 4.2 Layer targeting with `layerNo` and `layerName` + +Both options must be provided and must have the same number of semicolon-separated elements. + +| `layerNo` value | Meaning | +|---|---| +| `0` | All layers of the type in the matching `layerName` position | +| `N > 0` | Only the Nth occurrence of that layer type (1-indexed) | + +Examples: + +```yaml +# Inject into all conv and relu layers +options: + - -layerNo=0;0 + - -layerName=conv;relu + +# Inject only into the 2nd conv layer +options: + - -layerNo=2 + - -layerName=conv + +# Inject into every layer of every type +options: + - -layerNo=0 + - -layerName=all +``` + +Valid layer name values: `conv`, `relu`, `matmul`, `maxpool`, `add`, +`avgpool`, `loop`, `nonmaxs`, `unsqueeze`, `softmax`, `all`. + +### 4.3 Expanding the target set with `includeInjectionTrace` + +```yaml +includeInjectionTrace: + - forward +``` + +`forward` expands the injection candidate set to include all instructions that +are data-flow reachable from the selected operator's output (i.e. the forward +slice). `backward` includes all instructions that feed into the selected +operator's inputs (the backward slice). Omit this key to inject only into the +instructions directly inside the operator boundary. + +### 4.4 Multi-fault injection + +The MNIST `input.yaml` injects up to 2 faults per run, separated by a random +number of dynamic instructions drawn from `[1, 500]`: + +```yaml +fi_max_multiple: 2 +window_len_multiple_startindex: 1 +window_len_multiple_endindex: 500 +``` + +| Key | Meaning | +|---|---| +| `fi_max_multiple` | Maximum number of faults per run (≤ 100) | +| `window_len_multiple_startindex` | Lower bound on the inter-fault instruction gap | +| `window_len_multiple_endindex` | Upper bound on the inter-fault instruction gap | + +`fi_max_multiple` and `window_len` are mutually exclusive — use only one. + +For the full key reference see `docs/input_yaml_guide.md` and +`docs/input_masterlist_ml.yaml`. + +--- + +## 5. Instrument + +```bash +$LLFI_BUILD_ROOT/bin/instrument --readable \ + -L $ONNX_MLIR_BUILD/Debug/lib -lcruntime -ljson-c -lprotobuf \ + model.ll +``` + +The `-L` and `-l` flags link the onnx-mlir runtime libraries and json-c into +the instrumented binaries. + +`instrument` produces the `llfi/` directory: + +``` +llfi/ + model-profiling.exe Binary for the profiling pass + model-faultinjection.exe Binary for fault injection + model-profiling.ll Instrumented IR (profiling version) + model-faultinjection.ll Instrumented IR (fault injection version) + model-llfi_index.ll IR annotated with LLFI index numbers +``` + +Two top-level files are also written: + +| File | Contents | +|---|---| +| `llfi.log.compilation.txt` | Full pass output; check here if instrumentation fails | +| `llfi.config.compiletime.txt` | Summary of what was selected (failure class, layer, targets) | + +Check `llfi.config.compiletime.txt`: + +``` +failure_class=HardwareFault +failure_mode=CustomTensorOperator +targets= +injector= +``` + +If the file shows `0 candidate instructions`, either the operator names do not +match those in the model or the model was compiled without +`--instrument-onnx-ops`. + +--- + +## 6. Profile + +```bash +$LLFI_BUILD_ROOT/bin/profile \ + ./llfi/model-profiling.exe \ + eight.png \ + $(cat expected_op_seq.txt) +``` + +The arguments after the executable are passed to the model: +- `eight.png` — the input image +- `$(cat expected_op_seq.txt)` — the layer output sequence, so the driver knows + which intermediate tensors to save to `layeroutput.txt` + +Profiling runs the model once without injecting faults. It produces: + +| File | Contents | +|---|---| +| `llfi.stat.prof.txt` | Total cycle count, plus one `ml_layer` line per operator | +| `llfi.stat.totalindex.txt` | Number of unique injectable instruction indices | +| `llfi/baseline/golden_std_output` | Stdout of the fault-free run (predicted digit) | +| `layeroutput.txt` | Per-layer tensor values from the golden run (JSON format) | + +The ML-specific entries in `llfi.stat.prof.txt` look like: + +``` +total_cycle=4827613 +ml_layer=0,conv,12345,89012 +ml_layer=1,relu,89013,105400 +ml_layer=2,conv,105401,312800 +... +``` + +Each `ml_layer` line records: sequential layer number, operator type, start +dynamic-instruction cycle, end dynamic-instruction cycle. `injectfault` uses +this data to confine injection to cycles that fall within the requested layer. + +Move `layeroutput.txt` to a safe location before running fault injection, since +each faulty run will overwrite it: + +```bash +cp layeroutput.txt llfi/baseline/golden_layeroutput.txt +``` + +--- + +## 7. Inject faults + +```bash +$LLFI_BUILD_ROOT/bin/injectfault \ + ./llfi/model-faultinjection.exe \ + eight.png \ + $(cat expected_op_seq.txt) +``` + +This runs the model `numOfRuns` times (1000 in the sample `input.yaml`). +Each run draws a random cycle from within the layer timing ranges recorded +during profiling, injects up to `fi_max_multiple` bit-flips separated by a +random gap, and records the outcome. + +On completion, `llfi/` contains: + +``` +llfi/ + baseline/ + golden_std_output Reference output from profiling + golden_layeroutput.txt (you copied this manually above) + std_output/ + std_outputfile-run-0-0 Stdout from run 0 + std_outputfile-run-0-1 Stdout from run 1 + ... + error_output/ + errorfile-run-0-N Written only for crashed or timed-out runs + llfi_stat_output/ + llfi.stat.fi.injectedfaults.0-0.txt Injection details for run 0 + llfi.stat.fi.injectedfaults.0-1.txt + ... + prog_output/ + layeroutput.txt Layer output from the most recent run + (overwritten each run — save per-run if needed) +``` + +--- + +## 8. Interpret the results + +### 8.1 Injected fault stat files + +Each `llfi.stat.fi.injectedfaults.-.txt` records what happened in +one trial. For a multi-fault run it contains one line per injected fault: + +``` +FI stat: fi_type=bitflip, fi_max_multiple=2, fi_index=1042, fi_cycle=156203, + fi_reg_index=0, fi_reg_pos=0, fi_reg_width=32, fi_bit=17, opcode=fmul +ml_layer_name=conv +ml_layer_number=2 +FI stat: fi_type=bitflip, fi_max_multiple=2, fi_index=874, fi_cycle=156387, + fi_reg_index=0, fi_reg_pos=0, fi_reg_width=32, fi_bit=4, opcode=fadd +ml_layer_name=conv +ml_layer_number=2 +``` + +The fields from the C/C++ tutorial apply here too. The additional ML fields: + +| Field | Meaning | +|---|---| +| `ml_layer_name` | Operator type of the layer where the fault landed (`conv`, `relu`, …) | +| `ml_layer_number` | Sequential layer index (matches the `ml_layer=N,…` line in `llfi.stat.prof.txt`) | + +### 8.2 Classifying outcomes + +The outcome classification is the same as for C/C++ programs: + +| Outcome | How to identify | +|---|---| +| **Masked** | `std_outputfile` matches `golden_std_output` (same predicted digit); no `errorfile` | +| **SDC** | `std_outputfile` differs (wrong digit predicted); no `errorfile` | +| **Crash** | `errorfile` present with signal number (e.g. `-11` = SIGSEGV) | +| **Hang** | `errorfile` present with timeout message | + +Quick batch comparison: + +```bash +for f in llfi/std_output/std_outputfile-run-0-*; do + echo -n "$f: " + if diff -q "$f" llfi/baseline/golden_std_output > /dev/null 2>&1; then + echo "MASKED" + elif [ -f "llfi/error_output/errorfile-run-0-$(basename $f | sed 's/.*-//')" ]; then + echo "CRASH/HANG" + else + echo "SDC" + fi +done +``` + +For MNIST with 1000 runs you should expect a high masked rate (a single +bit-flip in a floating-point multiply is usually too small to change the +final argmax) with occasional SDC (wrong digit) and rare crashes. + +### 8.3 Layer-level analysis with `CompareLayerOutputs.py` + +To find which layer first produced a corrupted output, compare the JSON layer +outputs from a faulty run to the golden run. This requires saving +`layeroutput.txt` from each injection run before the next one overwrites it; +the simplest approach is to add a post-run copy step or reduce `numOfRuns` to 1 +when doing targeted investigation. + +```bash +# Run a single fault injection with the stat file pinning the fault +$LLFI_BUILD_ROOT/bin/injectfault \ + ./llfi/model-faultinjection.exe \ + eight.png \ + $(cat expected_op_seq.txt) + +# Compare layer outputs +python3 $LLFI_BUILD_ROOT/../tools/CompareLayerOutputs.py \ + --golden llfi/baseline/golden_layeroutput.txt \ + --faulty layeroutput.txt +``` + +`CompareLayerOutputs.py` prints the first layer whose output tensor differs +from the golden run, along with a summary of how many elements changed and by +how much. With `pygraphviz` installed it also writes a dot graph highlighting +the affected layers. + +--- + +## 9. Targeting a specific layer + +To restrict injection to a single layer, edit `input.yaml`. For example, to +inject only into the first convolutional layer: + +```yaml +compileOption: + instSelMethod: + - customInstselector: + include: + - CustomTensorOperator + options: + - -layerNo=1 + - -layerName=conv + + regSelMethod: regloc + regloc: dstreg + +runOption: + - run: + numOfRuns: 200 + fi_type: bitflip +``` + +Then delete the `llfi/` directory and re-run from instrumentation: + +```bash +rm -rf llfi llfi.stat.prof.txt llfi.stat.totalindex.txt \ + llfi.config.compiletime.txt llfi.log.compilation.txt +$LLFI_BUILD_ROOT/bin/instrument --readable \ + -L $ONNX_MLIR_BUILD/Debug/lib -lcruntime -ljson-c -lprotobuf \ + model.ll +$LLFI_BUILD_ROOT/bin/profile \ + ./llfi/model-profiling.exe eight.png $(cat expected_op_seq.txt) +$LLFI_BUILD_ROOT/bin/injectfault \ + ./llfi/model-faultinjection.exe eight.png $(cat expected_op_seq.txt) +``` + +Alternatively, `runllfi.sh` wraps these three steps (it deletes `llfi*/` first): + +```bash +./runllfi.sh +``` + +--- + +## 10. PyTorch path + +If you prefer PyTorch, use `compile-pytorch.sh` instead of `compile.sh`: + +```bash +./compile-pytorch.sh +``` + +The PyTorch path compiles without `--instrument-onnx-ops`, so the IR contains +no `OMInstrumentPoint` calls. This means: + +- `CustomTensorOperator` **cannot** be used — use `maingraph` or `insttype` instead +- No per-layer timing data in `llfi.stat.prof.txt` — faults are distributed + uniformly across all targeted instructions rather than being confined to a layer +- `CompareLayerOutputs.py` is not applicable because `expected_op_seq.txt` is + not generated and `layeroutput.txt` is not written + +Example `input.yaml` for PyTorch: + +```yaml +compileOption: + instSelMethod: + - customInstselector: + include: + - maingraph + + regSelMethod: regloc + regloc: dstreg + +runOption: + - run: + numOfRuns: 200 + fi_type: bitflip +``` + +--- + +## 11. Next steps + +- **Vary the layer**: change `layerName` and `layerNo` to compare fault + sensitivity across layers. +- **Vary the fault model**: change `fi_type` to `stuck_at_0` or `stuck_at_1` + to model permanent hardware faults instead of transient bit-flips. +- **Try a larger model**: `sample_programs/ml_sample_programs/vision_models/` + contains ResNet-50, VGG-16, GoogLeNet, and others — all use the same workflow. +- **Instruction duplication**: apply `SEDPasses.so` before instrumentation to + evaluate soft-error detection coverage. See + `llvm_passes/instruction_duplication/README.md`. +- **Batch across multiple models**: see `bin/batchInstrument.py`, + `batchProfile.py`, and `batchInjectfault.py` for running an experiment + campaign across many programs or fault modes in one call. +- **Read the architecture**: `architecture.md` §2.5 explains how + `CustomTensorOperatorInstSelector` and `OMInstrumentPoint` work together. diff --git a/installer/.gitignore b/installer/.gitignore deleted file mode 100644 index 57afa12e..00000000 --- a/installer/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -downloads/* -llfi/* -llfisrc/* -llvm/* -llvmsrc/* -pyyaml/* -pyyamlsrc/* \ No newline at end of file diff --git a/installer/InstallLLTFI.py b/installer/InstallLLTFI.py deleted file mode 100644 index 425f0647..00000000 --- a/installer/InstallLLTFI.py +++ /dev/null @@ -1,332 +0,0 @@ -import sys, os -import subprocess -if sys.version_info >= (3,): - import urllib.request as urllib2 - import urllib.parse as urlparse -else: - import urllib2 - import urlparse -import argparse - -LLTFIROOTDIRECTORY = "." - -def Touch(path): - with open(path, 'a'): - os.utime(path, None) - -# Helper function to return the python version installed in the system -def python3PrintParse(version): - return version.split()[1] - -# Get the Python3 major version and minor version -def python3Parse(version): - return version.split()[1].split('.')[:2] - -# Error Message to display if appropriate Python3 version is not found -python3Msg = "Error: Python 3 (python3) not found on path" + \ - " Pease ensure python3 is installed and is available on the path" + \ - " The latest version of Python3 can be downloaded from:" + \ - " https://www.python.org/downloads/" - -# Helper function to return the cmake version installed in the system -def CmakePrintParse(version): - return version.split()[2] - -# Get the cmake major version and minor version -def CmakeParse(version): - return version.split()[2].split('.')[:2] - -# Error Message to display if the appropriate cmake version is not found -cmakeMsg = "\tCmake 3.15+ can be downloaded from:\n\thttp://www.cmake.org/cmake/resources/software.html" - -# Helper function to return the ninja version installed in the system -def ninjaPrintParse(version): - return version.split()[0] - -# Get the ninja major version and minor version -def ninjaParse(version): - return version.split()[0].split('.')[:2] - -# Error Message to display if the appropriate ninja version is not found -ninjaMsg = "\tNinja 1.10+ can be downloaded from:\n\thttps://ninja-build.org/" - -# Helper function to return the pip version installed in the system -def pipPrintParse(version): - return version.split()[1] - -# Get the pip major version and minor version -def pipParse(version): - return version.split()[1].split('.')[:2] - -# Error Message to display if the appropriate pip version is not found -pipMsg = "\tPIP missing. Please install pip: apt-get install -y python3-pip" - -# Helper function to check if dependency versions present in the system are correct -def checkDep(name, execName, versionArg, printParseFunc, parseFunc, minVersion, msg): - try: - which = subprocess.check_output(['which', execName]) - print("Success: " + name + " Found at: " + str(which.strip()).lstrip("b\'").rstrip("\'")) - version = str(subprocess.check_output([execName, versionArg], stderr=subprocess.STDOUT).strip()) - version = version.lstrip("b'").rstrip('\'').replace('\\n',' ') - print("v", version) - try: - printVersion = str(printParseFunc(version)) - #print("pv", printVersion) - version = parseFunc(str(version).strip()) - #print("cv", version) - properVersion = True - - if int(version[0]) < minVersion[0]: - properVersion = False - elif (int(version[0]) == minVersion[0]) and (int(version[1]) < minVersion[1]): - properVersion = False - if properVersion: - print("Success: " + name + "(" + printVersion + ") is at or above version " + ".".join([str(x) for x in minVersion])) - return True - else: - print("Error: " + name + "(" + printVersion + ") is below version " + ".".join([str(x) for x in minVersion])) - print(msg) - return False - except: - print("Warning, " + name + " detected on path, but unable to parse version info.") - print("Please ensure that " + name + " is at least of version: " + '.'.join([str(x) for x in minVersion])) - return True - except(subprocess.CalledProcessError): - print("Error: " + name + " (" + execName + ") not found on path") - print(" Pease ensure " + name + " is installed and is available on the path") - print(msg) - return False - -# Check if all dependencies(Python >= v3.8, CMake > v3.15, Ninja > 1.10.2) are installed. Exit the program with an error message if dependencies are missing. -# Install PyYaml, tensorflow and tfonnx packages if not already installed -def checkDependencies(): - hasAll = True - hasAll = checkDep("Python 3", "python3", "--version", python3PrintParse, python3Parse, [3,8], python3Msg) and hasAll - hasAll = checkDep("Cmake","cmake","--version", CmakePrintParse, CmakeParse, [3,15], cmakeMsg) and hasAll - hasAll = checkDep("Ninja","ninja","--version", ninjaPrintParse, ninjaParse, [1,10], ninjaMsg) and hasAll - hasAll = checkDep("pip3","pip3","--version", pipPrintParse, pipParse, [20,0], pipMsg) and hasAll - - if not hasAll: - return hasAll - - try: - import pkg_resources - versionInfo = pkg_resources.get_distribution("pyyaml").version - if versionInfo[0] != '5' or versionInfo[2] != '4' or versionInfo[4] != '1': - print("Incorrect PyYaml version. Please install v5.4.1") - sys.exit(-1) - except: - print("PyYaml missing. Installing PyYaml..") - os.system("pip3 install pyyaml===5.4.1") - - try: - import tensorflow as tf - versionInfo = tf.__version__ - if versionInfo[0] != '2': - print("Incorrect tensorflow version. Please install v2.0 or greater") - sys.exit(-1) - except ImportError: - print("Tensorflow missing. Installing Tensorflow..") - os.system("pip install tensorflow") - - try: - import tf2onnx as tfonnx - except ImportError: - print("tf2onnx missing. Installing tf2onnx..") - os.system("pip install tf2onnx") - - return hasAll - -# Clone the LLVM, ONNX-MLIR and LLTFI repos -def downloadSource(): - # LLTFI - os.system('git clone https://github.com/DependableSystemsLab/LLTFI.git') - - # LLVM - os.system('git clone https://github.com/llvm/llvm-project.git') - os.chdir('llvm-project') - os.system('git checkout 9778ec057cf4') - os.chdir(os.path.join('..')) - - # onnx-mlir - os.system('git clone --recursive https://github.com/DependableSystemsLab/onnx-mlir-lltfi.git') - os.chdir('onnx-mlir-lltfi') - os.system('git checkout LLTFI') - os.chdir(os.path.join('..')) - os.rename('onnx-mlir-lltfi','onnx-mlir') - -# Helper function to check if a particular directory already exists -def CheckDirExists(dir): - FullPath = os.path.abspath(dir) - if (os.path.exists(FullPath)): - if (os.path.isdir(FullPath)): - print("%s directory exists." % (dir)) - return True - else: - print("%s directory missing" % (dir)) - return False - -# Build and install LLVM 15.0, onnx-mlir and LLTFI -def buildSource(): - # Build and install LLVM - if CheckDirExists('llvm-project'): - os.chdir("llvm-project") - if (not CheckDirExists('build')): - os.mkdir("build") - os.chdir("build") - if (not os.path.exists("CMAKESUCCESS")): - print("Running cmake for LLVM:") - p = subprocess.call(["cmake", "-G", "Ninja", "../llvm", "-DLLVM_ENABLE_PROJECTS='clang;mlir'","-DLLVM_BUILD_TESTS=ON", "-DLLVM_TARGETS_TO_BUILD='host'", "-DLLVM_ENABLE_ASSERTIONS=ON", "-DLLVM_ENABLE_RTTI=ON"]) - p1 = subprocess.call(["cmake", "--build", ".", "--target", "clang", "check-mlir", "mlir-translate", "opt", "llc", "lli", "llvm-dis", "llvm-link", "-j1" ]) - if p != 0 or p1 != 0: - sys.exit(p) - Touch("CMAKESUCCESS") - - if (not os.path.exists("MAKESUCCESS")): - print("Running make for LLVM") - p = subprocess.call(["ninja", "install", "-j1"]) - if p != 0: - sys.exit(p) - Touch("MAKESUCCESS") - os.chdir(os.path.join('../..')) - - - else: - print("LLVM source code missing. Run the installer script without -nD option") - - # Build and install ONNX-MLIR - if CheckDirExists('onnx-mlir'): - cwd = os.getcwd() - os.environ['MLIR_DIR'] = cwd + '/llvm-project/build/lib/cmake/mlir' - os.chdir("onnx-mlir") - if (not CheckDirExists('build')): - os.mkdir("build") - os.chdir("build") - if (not os.path.exists("CMAKESUCCESS")): - print("Running cmake for ONNX-MLIR:") - p = subprocess.call(["cmake", "-G", "Ninja", "-DCMAKE_CXX_COMPILER=/usr/bin/c++", "-DMLIR_DIR=${MLIR_DIR}", ".."]) - p1 = subprocess.call(["cmake", "--build", ".", "-j1" ]) - if p != 0 or p1 != 0: - sys.exit(p) - Touch("CMAKESUCCESS") - - if (not os.path.exists("MAKESUCCESS")): - print("Running make for ONNX-MLIR") - p = subprocess.call(["ninja", "install", "-j1"]) - if p != 0: - sys.exit(p) - Touch("MAKESUCCESS") - os.chdir(os.path.join('../..')) - - else: - print("ONNX-MLIR source missing. Run the installer script without -nD option") - - - # Build LLTFI - if CheckDirExists('LLTFI'): - os.chdir("LLTFI") - if (not os.path.exists("BUILDSUCCESS")): - print("Building LLTFI:") - p = os.system('./setup -LLFI_BUILD_ROOT $(pwd)/build -LLVM_SRC_ROOT $(pwd)/../llvm-project -LLVM_DST_ROOT $(pwd)/../llvm-project/build') - if p != 0: - sys.exit(p) - Touch("BUILDSUCCESS") - os.chdir(os.path.join('..')) - - else: - print("LLTFI source missing. Run the installer script without -nD option") - - cwd = os.getcwd() - os.environ['LLFI_BUILD_ROOT'] = cwd + '/LLTFI/build' - -# Helper function to download a zip file from a url -def DownloadFile(url, destinationDirectory, desc=None): - u = urllib2.urlopen(url) - - scheme, netloc, path, query, fragment = urlparse.urlsplit(url) - filename = os.path.basename(path) - if not filename: - filename = 'downloaded.file' - if desc: - filename = os.path.join(desc, filename) - - with open(os.path.join(destinationDirectory, filename), 'wb') as f: - meta = u.info() - meta_func = meta.getheaders if hasattr(meta, 'getheaders') else meta.get_all - meta_length = meta_func("Content-Length") - file_size = None - if meta_length: - file_size = int(meta_length[0]) - print("Downloading: {0} Bytes: {1}".format(url, file_size)) - - file_size_dl = 0 - block_sz = 8192 - while True: - buffer = u.read(block_sz) - if not buffer: - break - - file_size_dl += len(buffer) - f.write(buffer) - - status = "{0:16}".format(file_size_dl) - if file_size: - status += " [{0:6.2f}%]".format(file_size_dl * 100 / file_size) - status += chr(13) - print(status, end="") - print() - - return filename - -# Download and install libprotoc v.3.17.2 -def downloadAndInstallProtobuf(): - DownloadFile("https://github.com/protocolbuffers/protobuf/releases/download/v3.17.2/protobuf-all-3.17.2.zip", ".") - os.system("unzip protobuf-all-3.17.2.zip") - os.chdir("protobuf-3.17.2") - os.system("./configure") - os.system("make -j2") - os.system("make check -j2") - os.system("make install") - os.system("ldconfig") - os.chdir(os.path.join('../..')) - -# Run LLTFI regression tests -def runTests(): - LLFI_BUILD_DIR = os.path.dirname(os.path.realpath(__file__)) - subprocess.call(["python3", LLFI_BUILD_DIR + "/LLTFI/build/test_suite/SCRIPTS/llfi_test", "--all", "--threads", "2", "--verbose"]) - -parser = argparse.ArgumentParser( - description=("Installer for UBC DependableSystemsLab's LLTFI"), - epilog="More information available at www.github.com/DependableSystemsLab/LLTFI", - usage='%(prog)s [options]') -parser.add_argument("-v", "--version", action="version", version="LLTFI Installer v0.1, September 23, 2022") -parser.add_argument("-sDC", "--skipDependencyCheck", action='store_true', help="Skip Dependency Checking") -parser.add_argument("-nPb", "--noProtobuf", action='store_true', help="Do not download and install Protobuf") -parser.add_argument("-nD", "--noDownload", action='store_true', help="Do not download any files") -parser.add_argument("-nB", "--noBuild", action='store_true', help="Do not perform installation, only download") -parser.add_argument("-rT", "--runTests", action='store_true', help="Run all regression tests for LLTFI after installation") - - -if __name__ == "__main__": - args = parser.parse_args(sys.argv[1:]) - if not args.skipDependencyCheck: - print("Checking LLTFI Pre-Requisites and Dependencies") - deps = checkDependencies() - if not deps: - print("Some LLTFI Pre-Requisites are missing!") - print("Please see Errors above, and install the missing dependencies") - print("Exiting Installer...") - sys.exit(-1) - # If the "-nPb" option is not specified, download and install protobuf - if not args.noProtobuf: - downloadAndInstallProtobuf() - print("Installing LLTFI to: " + os.path.abspath(LLTFIROOTDIRECTORY)) - # If "-nD" option is not specifies, clone all the required github repositories - if not args.noDownload: - downloadSource() - # If "-nB" option is not specified, Build and install - if not args.noBuild: - buildSource() - # Run LLTFI regression tests - if args.runTests: - runTests() diff --git a/lint.sh b/lint.sh new file mode 100755 index 00000000..f5b58057 --- /dev/null +++ b/lint.sh @@ -0,0 +1,174 @@ +#!/usr/bin/env bash +# lint.sh — Run all LLTFI linters (clang-tidy, clang-format, flake8). +# +# Usage: +# ./lint.sh # check only (exit 1 if any issues found) +# ./lint.sh --fix # auto-fix clang-format and flake8 issues in-place +# ./lint.sh --cpp # C++ checks only +# ./lint.sh --python # Python checks only +# ./lint.sh --install # install missing Python tools then lint +# +# Requirements: +# C++: clang-tidy-20 and clang-format-20 (apt install clang-tidy-20 clang-format-20) +# compile_commands.json in LLTFI-build/ +# (add -DCMAKE_EXPORT_COMPILE_COMMANDS=ON to cmake to generate it) +# Python: flake8 and flake8-bugbear (pip install flake8 flake8-bugbear) + +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${LLFI_BUILD_ROOT:-${REPO_ROOT}/../LLTFI-build}" +FIX=0 +RUN_CPP=1 +RUN_PYTHON=1 +ERRORS=0 + +# --------------------------------------------------------------------------- +# Parse arguments +# --------------------------------------------------------------------------- +for arg in "$@"; do + case "$arg" in + --fix) FIX=1 ;; + --cpp) RUN_PYTHON=0 ;; + --python) RUN_CPP=0 ;; + --install) + echo "==> Installing Python lint tools..." + pip3 install --quiet flake8 flake8-bugbear + ;; + --help|-h) + head -15 "$0" | grep '^#' | sed 's/^# \?//' + exit 0 + ;; + *) + echo "Unknown argument: $arg" >&2 + exit 1 + ;; + esac +done + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +find_tool() { + # find_tool [fallback...] + for t in "$@"; do + if command -v "$t" &>/dev/null; then echo "$t"; return 0; fi + done + return 1 +} + +section() { echo; echo "==> $*"; } +pass() { echo " PASS: $*"; } +fail() { echo " FAIL: $*"; ERRORS=$((ERRORS + 1)); } +warn() { echo " WARN: $* (tool not found — skipping)"; } + +# --------------------------------------------------------------------------- +# C++ linting +# --------------------------------------------------------------------------- +if [[ $RUN_CPP -eq 1 ]]; then + + # clang-format + section "clang-format (C++ formatting)" + if CFMT=$(find_tool clang-format-20 clang-format); then + CPP_FILES=$(find "${REPO_ROOT}/llvm_passes" -name '*.cpp' -o -name '*.h' ) + FMT_ISSUES=0 + for f in $CPP_FILES; do + if [[ $FIX -eq 1 ]]; then + "$CFMT" -i "$f" + else + diff_out=$("$CFMT" --dry-run --Werror "$f" 2>&1 || true) + if [[ -n "$diff_out" ]]; then + echo " $f: formatting issues (run lint.sh --fix to auto-fix)" + FMT_ISSUES=$((FMT_ISSUES + 1)) + fi + fi + done + if [[ $FIX -eq 1 ]]; then + pass "reformatted all C++ files" + elif [[ $FMT_ISSUES -eq 0 ]]; then + pass "all C++ files are correctly formatted" + else + fail "$FMT_ISSUES file(s) have formatting issues" + fi + else + warn "clang-format-20 not found (apt install clang-format-20)" + fi + + # clang-tidy + section "clang-tidy (C++ static analysis)" + if CTIDY=$(find_tool clang-tidy-20 clang-tidy); then + COMPDB="${BUILD_DIR}/compile_commands.json" + if [[ ! -f "$COMPDB" ]]; then + warn "compile_commands.json not found at ${COMPDB}." \ + "Rebuild with: cmake -DCMAKE_EXPORT_COMPILE_COMMANDS=ON ..." + else + TIDY_ISSUES=0 + CPP_SOURCES=$(find "${REPO_ROOT}/llvm_passes" -name '*.cpp') + for f in $CPP_SOURCES; do + result=$("$CTIDY" -p "$BUILD_DIR" --quiet "$f" 2>&1 \ + | grep -v "^[0-9]* warnings generated\.$" || true) + if [[ -n "$result" ]]; then + echo "$result" + TIDY_ISSUES=$((TIDY_ISSUES + 1)) + fi + done + if [[ $TIDY_ISSUES -eq 0 ]]; then + pass "no clang-tidy issues found" + else + fail "$TIDY_ISSUES file(s) have clang-tidy issues" + fi + fi + else + warn "clang-tidy-20 not found (apt install clang-tidy-20)" + fi + +fi + +# --------------------------------------------------------------------------- +# Python linting +# --------------------------------------------------------------------------- +if [[ $RUN_PYTHON -eq 1 ]]; then + + section "flake8 (Python style and correctness)" + if python3 -m flake8 --version &>/dev/null; then + PYTHON_DIRS=( + "${REPO_ROOT}/bin" + "${REPO_ROOT}/test_suite/SCRIPTS" + "${REPO_ROOT}/tools/GenerateMakefile" + ) + # Filter to directories that actually exist + EXISTING_DIRS=() + for d in "${PYTHON_DIRS[@]}"; do + [[ -d "$d" ]] && EXISTING_DIRS+=("$d") + done + + if [[ ${#EXISTING_DIRS[@]} -gt 0 ]]; then + if python3 -m flake8 "${EXISTING_DIRS[@]}"; then + pass "no flake8 issues found" + else + fail "flake8 reported issues" + fi + fi + + # Check that flake8-bugbear is installed (catches bare except:, shell=True) + if ! python3 -m flake8 --select=B --quiet /dev/null 2>/dev/null; then + warn "flake8-bugbear not installed — bare-except and shell=True checks skipped" \ + "(pip install flake8-bugbear)" + fi + else + warn "flake8 not found (pip install flake8 flake8-bugbear)" + fi + +fi + +# --------------------------------------------------------------------------- +# Summary +# --------------------------------------------------------------------------- +echo +if [[ $ERRORS -eq 0 ]]; then + echo "==> All lint checks passed." + exit 0 +else + echo "==> $ERRORS lint check(s) failed." + exit 1 +fi diff --git a/llvm_passes/CMakeLists.txt b/llvm_passes/CMakeLists.txt index 77c5afcd..307e38ba 100644 --- a/llvm_passes/CMakeLists.txt +++ b/llvm_passes/CMakeLists.txt @@ -1,6 +1,10 @@ include(../config/llvm_passes.cmake) -set(LLVM_PASSES_DIRS_LLFI hardware_failures core software_failures) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(LLVM_PASSES_DIRS_LLFI hardware_failures core) include_directories(${LLVM_PASSES_DIRS_LLFI}) @@ -8,7 +12,6 @@ include_directories(${LLVM_PASSES_DIRS_LLFI}) add_llvm_library(llfi-passes MODULE SampleFIInstSelector.cpp SampleFIRegSelector.cpp - SoftwareFailureAutoScanPass.cpp HardwareFailureAutoScanPass.cpp MainGraphInstSelector.cpp CustomTensorOperatorInstSelector.cpp diff --git a/llvm_passes/CustomTensorOperatorInstSelector.cpp b/llvm_passes/CustomTensorOperatorInstSelector.cpp index 7bd50746..7d8e4ac8 100644 --- a/llvm_passes/CustomTensorOperatorInstSelector.cpp +++ b/llvm_passes/CustomTensorOperatorInstSelector.cpp @@ -1,55 +1,75 @@ -#include "llvm/IR/Instructions.h" -#include "llvm/Support/CommandLine.h" - -#include "FIInstSelector.h" #include "FICustomSelectorManager.h" +#include "FIInstSelector.h" #include "Utils.h" -#include "FICustomSelectorManager.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/CommandLine.h" -#include -#include #include +#include #include +#include #include -#include +#include using namespace llvm; namespace llfi { - -static cl::list< std::string > layerNo("layerNo", cl::desc("Layer Number in \ +static cl::list layerNo("layerNo", cl::desc("Layer Number in \ which you want to inject bitflip faults. Pass 0 for injecting faults in all the \ -layers.\n Semi-colon seperated values. Example: 1;0;2"), cl::ZeroOrMore); +layers.\n Semi-colon seperated values. Example: 1;0;2"), + cl::ZeroOrMore); -static cl::list< std::string > layerName("layerName", cl::desc("Layer Name in \ +static cl::list layerName("layerName", cl::desc("Layer Name in \ which you want to inject bitflip faults. Semi-colon seperated values. Example: \ -Conv;Relu;Pool"), cl::ZeroOrMore); - +Conv;Relu;Pool"), + cl::ZeroOrMore); // Return an array our of string of comma-seperated values. std::vector getCommaSeperateVals(std::string inp) { - std::string s = inp; - std::string delimiter = ";"; - std::vector retval; - size_t pos = 0; + std::string s = inp; + std::string delimiter = ";"; + std::vector retval; + size_t pos = 0; - std::string token; - while ((pos = s.find(delimiter)) != std::string::npos) { - token = s.substr(0, pos); - retval.push_back(token); - s.erase(0, pos + delimiter.length()); - } + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + retval.push_back(token); + s.erase(0, pos + delimiter.length()); + } - if ((pos = s.find(delimiter)) == std::string::npos) { - retval.push_back(s); - } + retval.push_back(s); - return retval; + return retval; } +// Return the ONNX operator name as a string. +std::string extractONNXOperatorName(Value *V) { + auto *GV = dyn_cast(V); + Constant *Init = GV->getInitializer(); + + auto *CDA = dyn_cast(Init); + if (!CDA || !CDA->isString()) { + return ""; + } + + StringRef onnxOpNameStrRef = CDA->getAsString(); + std::string onnxOpNameStr = onnxOpNameStrRef.str(); + + std::transform(onnxOpNameStr.begin(), onnxOpNameStr.end(), + onnxOpNameStr.begin(), + [](unsigned char c) { return std::tolower(c); }); + + std::string onnxOpPrefix = "onnx."; + size_t pos = onnxOpNameStr.find(onnxOpPrefix); + std::string result = onnxOpNameStr.substr(pos + onnxOpPrefix.length()); + return result; +} /** * This sample instruction selector only selects instructions in function @@ -58,219 +78,198 @@ std::vector getCommaSeperateVals(std::string inp) { class CustomTensorOperatorInstSelector : public HardwareFIInstSelector { public: - // Data structure to keep track of every Tensor Operator. - struct Operator { - - std::string OperatorName; - // ONNX mlir assigns a unique ID to every operator. - int64_t OperatorNumber; - // Number of times we have seen this operator. - int OperatorCount; - // Operator number to do FI - int FIOperatorCount; - - // Get unique Id corresponding to the ONNX operator. - static int64_t getOperatorNumber(std::string name) { - - char opname[100]; - std::transform(name.begin(), name.end(), name.begin(), - [](unsigned char c){ return std::tolower(c); }); - - strcpy(opname, name.c_str()); - - std::cout<<"OperatorName: "< ONNXOperatorId = { - {"conv", 1986948931}, - {"relu", 1970038098}, - {"maxpool", 30521821366870349}, - {"matmul", 119251066446157}, - {"add", 6579265}, - {"avgpool", 30521821365761601}, - {"softmax", 33884119937478483}, - {"loop", 1886351180}, - {"nonmaxs", 23494782373228366}, - {"unsqueeze", 28540527736745557} - }; - - if (ONNXOperatorId.find(opname) == ONNXOperatorId.end()) - return -1; - - return ONNXOperatorId[opname]; - } - - Operator(std::string name, std::string count) { + // Data structure to keep track of every Tensor Operator. + struct Operator { - OperatorName = name; - FIOperatorCount = atoll(count.c_str()); - OperatorCount = 0; - OperatorNumber = getOperatorNumber(name); + std::string OperatorName; + // Number of times we have seen this operator. + int OperatorCount; + // Operator number to do FI + int FIOperatorCount; - if (OperatorNumber == -1) { - std::cout<<"Operator name "<< OperatorName.c_str() << - " not found.\n"; - std::cout<<"Please use the following operator name(s):\ - conv, relu, maxpool, matmul, add, avgpool, all, and softmax."; - assert(false && "Invalid input operator name"); - } - - assert(FIOperatorCount >= 0 && "Invalid input FI operator number"); - } - - bool doFaultInjection(){ + // Check if provided ONNX operator is valid and supported. + bool isValidOperator(std::string name) { - OperatorCount++; + std::vector ONNXOperators = { + "conv", "relu", "maxpool", "matmul", "add", + "avgpool", "softmax", "loop", "nonmaxs", "unsqueeze"}; - // Inject fault in the user-specified operator count. - if (FIOperatorCount == 0 || FIOperatorCount == OperatorCount) - return true; - - return false; - } - }; // End of struct Operator. - -private: - bool isCustomTensorOperator; - std::unordered_map> map; - bool injectInAll; - - // Add Metadata to LLVM instructions; Only for debugging purposes! - void addMetadata(llvm::Instruction *ins, char *st = NULL){ - LLVMContext& C = ins->getContext(); - MDNode* N = MDNode::get(C, MDString::get(C, (!st) ? "t" : st)); - ins->setMetadata("Debug", N); + return (std::find(ONNXOperators.begin(), ONNXOperators.end(), name) != + ONNXOperators.end()); } - // Initializes Layer name and number - void initializeLayerNameAndNumber(std::string layerNo, - std::string layerName) { - - std::vector OperatorNames = - getCommaSeperateVals(layerName); - std::vector OperatorNumbers = - getCommaSeperateVals(layerNo); + Operator(std::string name, std::string count) { - assert(OperatorNumbers.size() == OperatorNames.size() && - "Number of CSVs given to the layerNo and layerName should be equal"); + OperatorName = name; + FIOperatorCount = (int)atoll(count.c_str()); + OperatorCount = 0; - for (int i = 0; i < OperatorNames.size(); i++) { + if (!isValidOperator(OperatorName)) { + std::cout << "Operator name " << OperatorName << " not found.\n"; + std::cout << "Please use the following operator name(s):\ + conv, relu, maxpool, matmul, add, avgpool, all, and softmax."; + assert(false && "Invalid input operator name"); + } - std::string name = OperatorNames[i]; - std::string number = OperatorNumbers[i]; + assert(FIOperatorCount >= 0 && "Invalid input FI operator number"); + } - // Inject in all operators. - if (strcmp(name.c_str(), "all") == 0 || - strcmp(name.c_str(), "All") == 0) { - injectInAll = true; - break; - } + bool doFaultInjection() { - int64_t code = Operator::getOperatorNumber(name); + OperatorCount++; - // if this operator is already in the map - if (map.find(code) != map.end()) { + // Inject fault in the user-specified operator count. + if (FIOperatorCount == 0 || FIOperatorCount == OperatorCount) + return true; - Operator *temp = new Operator(name, number); - map[code].push_back(temp); - } - else { + return false; + } + }; // End of struct Operator. - std::vector OpArr; - Operator *temp = new Operator(name, number); - OpArr.push_back(temp); - map.insert(make_pair(code, OpArr)); - } - } +private: + bool isCustomTensorOperator; + std::unordered_map> map; + bool injectInAll; + int64_t instrumentPoint; + + // Add Metadata to LLVM instructions; Only for debugging purposes! + void addMetadata(llvm::Instruction *ins, const char *st = nullptr) { + LLVMContext &C = ins->getContext(); + MDNode *N = MDNode::get(C, MDString::get(C, (!st) ? "t" : st)); + ins->setMetadata("Debug", N); + } + + // Initializes Layer name and number + void initializeLayerNameAndNumber(std::string layerNo, + std::string layerName) { + + std::vector OperatorNames = getCommaSeperateVals(layerName); + std::vector OperatorNumbers = getCommaSeperateVals(layerNo); + + assert(OperatorNumbers.size() == OperatorNames.size() && + "Number of CSVs given to the layerNo and layerName should be equal"); + + for (int i = 0; i < (int)OperatorNames.size(); i++) { + + const std::string &name = OperatorNames[i]; + const std::string &number = OperatorNumbers[i]; + + // Inject in all operators. + if (strcmp(name.c_str(), "all") == 0 || + strcmp(name.c_str(), "All") == 0) { + injectInAll = true; + break; + } + + // if this operator is already in the map + if (map.find(name) != map.end()) { + + Operator *temp = new Operator(name, number); + map[name].push_back(temp); + } else { + + std::vector OpArr; + Operator *temp = new Operator(name, number); + OpArr.push_back(temp); + map.insert(make_pair(name, OpArr)); + } } + } - bool shouldInjectFault(int64_t number) { + bool shouldInjectFault(std::string opName) { - if (injectInAll) return true; + if (injectInAll) + return true; - // If the operator isn't present in the map. - if (map.find(number) == map.end()) return false; - else { + // If the operator isn't present in the map. + if (map.find(opName) == map.end()) + return false; + else { - std::vector temp = map[number]; - bool result; + std::vector temp = map[opName]; + bool result = false; - for (auto it : temp) { - result |= it->doFaultInjection(); - } + for (auto it : temp) { + result |= it->doFaultInjection(); + } - return result; - } + return result; } + } - virtual bool isInstFITarget(Instruction *inst) { - if (inst->getParent()->getParent()->getName() == "main_graph") { + bool isInstFITarget(Instruction *inst) override { + if (inst->getParent()->getParent()->getName().starts_with("main_graph")) { - if (map.size() == 0 && !injectInAll){ - initializeLayerNameAndNumber(layerNo[0], layerName[0]); - } + if (map.empty() && !injectInAll) { + initializeLayerNameAndNumber(layerNo[0], layerName[0]); + } - if (inst->getOpcode() == Instruction::Call){ - CallInst* callinst = dyn_cast(inst); + if (inst->getOpcode() == Instruction::Call) { + CallInst *callinst = cast(inst); - // If this is OMInstrument function? - if ((callinst->getCalledFunction())->getName() == - "OMInstrumentPoint") { + // If this is OMInstrument function? + if (callinst->getCalledFunction() && + callinst->getCalledFunction()->getName() == "OMInstrumentPoint") { - Value* arg1 = callinst->getArgOperand(0); - Value* arg2 = callinst->getArgOperand(1); + Value *arg1 = callinst->getArgOperand(0); + std::string onnxOpName = extractONNXOperatorName(arg1); - ConstantInt* ci1 = dyn_cast(arg1); - ConstantInt* ci2 = dyn_cast(arg2); + Value *arg2 = callinst->getArgOperand(1); - int64_t argValue1 = ci1->getSExtValue(); - int64_t argValue2 = ci2->getSExtValue(); + ConstantInt *ci = dyn_cast(arg2); + if (onnxOpName == "" || !ci) + return false; + + int64_t argValue2 = ci->getSExtValue(); - if (argValue2 == 1 && shouldInjectFault(argValue1)) { + if (instrumentPoint == 0 && shouldInjectFault(onnxOpName)) { - // Inject fault! - isCustomTensorOperator = true; - } + // Inject fault! + isCustomTensorOperator = true; + instrumentPoint = argValue2; + } - if (argValue2 == 2) { + if (argValue2 == instrumentPoint + 1) { - // Set this to false after the operator ends. - isCustomTensorOperator = false; - } - } - } + // Set this to false after the operator ends. + isCustomTensorOperator = false; + instrumentPoint = 0; + } + } + } - if (!isCustomTensorOperator) return false; + if (!isCustomTensorOperator) + return false; - // Injecting fault. - if (inst->getOpcode() == Instruction::FAdd || - inst->getOpcode() == Instruction::FSub || - inst->getOpcode() == Instruction::FMul || - inst->getOpcode() == Instruction::FDiv || - inst->getOpcode() == Instruction::FCmp) { + // Injecting fault. + if (inst->getOpcode() == Instruction::FAdd || + inst->getOpcode() == Instruction::FSub || + inst->getOpcode() == Instruction::FMul || + inst->getOpcode() == Instruction::FDiv || + inst->getOpcode() == Instruction::FCmp) { - addMetadata(inst, "Injected fault"); - return true; - } + addMetadata(inst, "Injected fault"); + return true; + } - return false; // Inject Fault in all instructions - } - return false; + return false; // Inject Fault in all instructions } + return false; + } public: - CustomTensorOperatorInstSelector(){ - isCustomTensorOperator = false; - injectInAll = false; - } - - virtual void getCompileTimeInfo(std::map &info) { - info["failure_class"] = "HardwareFault"; - info["failure_mode"] = "CustomTensorOperator"; - info["targets"] = " &info) override { + info["failure_class"] = "HardwareFault"; + info["failure_mode"] = "CustomTensorOperator"; + info["targets"] = ""; - info["injector"] = ""; - } + info["injector"] = ""; + } }; static RegisterFIInstSelector X("CustomTensorOperator", diff --git a/llvm_passes/HardwareFailureAutoScanPass.cpp b/llvm_passes/HardwareFailureAutoScanPass.cpp index 1782cfa1..57130be8 100644 --- a/llvm_passes/HardwareFailureAutoScanPass.cpp +++ b/llvm_passes/HardwareFailureAutoScanPass.cpp @@ -1,81 +1,101 @@ -#define DEBUG_TYPE "HardwareFailureAutoScanPass" - #include "FICustomSelectorManager.h" -#include "Utils.h" -#include "FIInstSelectorManager.h" #include "FIInstSelector.h" -#include "InstTypeFIInstSelector.h" -#include "FuncNameFIInstSelector.h" +#include "FIInstSelectorManager.h" #include "FIRegSelector.h" +#include "InstTypeFIInstSelector.h" #include "RegLocBasedFIRegSelector.h" +#include "Utils.h" -#include "llvm/Pass.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" #include #include +#include "FuncNameFIInstSelector.h" + using namespace llvm; -namespace llfi{ - static cl::opt< std::string > outputpath("hardwarescan_outputfilename", - cl::desc("The path to store a list of applicable software failures"), - cl::init("llfi.applicable.hardware.selectors.txt")); +namespace llfi { +static cl::opt outputpath( + "hardwarescan_outputfilename", + cl::desc("The path to store a list of applicable software failures"), + cl::init("llfi.applicable.hardware.selectors.txt")); + +class HardwareFailureAutoScanPass : public ModulePass { +private: + std::ofstream selector_record_file; + +public: + static char ID; + HardwareFailureAutoScanPass() : ModulePass(ID) {} + bool runOnModule(Module &M) override { + selector_record_file.open(outputpath.getValue().c_str(), + std::ofstream::out); - class HardwareFailureAutoScanPass: public ModulePass{ - private: - std::ofstream selector_record_file; - public: - static char ID; - HardwareFailureAutoScanPass():ModulePass(ID){} - virtual bool runOnModule(Module &M){ - selector_record_file.open(std::string(outputpath).c_str(), std::ofstream::out); + FICustomInstSelectorManager *im = + FICustomInstSelectorManager::getCustomInstSelectorManager(); + FICustomRegSelectorManager *rm = + FICustomRegSelectorManager::getCustomRegSelectorManager(); + std::set all_hardware_inst_selector_names; + im->getAllHardwareSelectors(all_hardware_inst_selector_names); + all_hardware_inst_selector_names.insert(std::string("insttype")); + all_hardware_inst_selector_names.insert(std::string("funcname")); - FICustomInstSelectorManager *im = FICustomInstSelectorManager::getCustomInstSelectorManager(); - FICustomRegSelectorManager *rm = FICustomRegSelectorManager::getCustomRegSelectorManager(); - std::set all_hardware_inst_selector_names; - im->getAllHardwareSelectors(all_hardware_inst_selector_names); - all_hardware_inst_selector_names.insert(std::string("insttype")); - all_hardware_inst_selector_names.insert(std::string("funcname")); + std::set all_hardware_reg_selector_names; + rm->getAllHardwareSelectors(all_hardware_reg_selector_names); + all_hardware_reg_selector_names.insert(std::string("regloc")); + // errs()<<"get all soft failures\n"; - std::set all_hardware_reg_selector_names; - rm->getAllHardwareSelectors(all_hardware_reg_selector_names); - all_hardware_reg_selector_names.insert(std::string("regloc")); - // errs()<<"get all soft failures\n"; + recordString(std::string("instSelMethod:")); + for (std::set::iterator name = + all_hardware_inst_selector_names.begin(); + name != all_hardware_inst_selector_names.end(); name++) { + { + recordString(std::string(" - ") + *name); + } + } - recordString(std::string("instSelMethod:")); - for(std::set::iterator name = all_hardware_inst_selector_names.begin(); - name != all_hardware_inst_selector_names.end(); name++){ - { - recordString(std::string(" - ") + *name); - } - } + recordString(std::string("regSelMethod:")); + for (std::set::iterator name = + all_hardware_reg_selector_names.begin(); + name != all_hardware_reg_selector_names.end(); name++) { + { + recordString(std::string(" - ") + *name); + } + } + selector_record_file.close(); + return false; + } - recordString(std::string("regSelMethod:")); - for(std::set::iterator name = all_hardware_reg_selector_names.begin(); - name != all_hardware_reg_selector_names.end(); name++){ - { - recordString(std::string(" - ") + *name); - } - } - selector_record_file.close(); - } + void recordString(std::string str) { + if (selector_record_file.is_open() == false) { + errs() << "ERROR: can not open file to record applicable selectors: " + << outputpath << "\n"; + selector_record_file.close(); + return; + } + selector_record_file << str << "\n"; + return; + } +}; +char HardwareFailureAutoScanPass::ID = 0; +static RegisterPass + X("HardwareFailureAutoScanPass", + "Automatic scanner of hardware failure modes (instruction selectors, reg " + "selectors)", + false, false); +} // namespace llfi - void recordString(std::string str){ - if(selector_record_file.is_open() == false){ - std::cerr<<"ERROR: can not open file to record applicable selectors: "; - std::cerr< - X("HardwareFailureAutoScanPass", "Automatic scanner of hardware failure modes (instruction selectors, reg selectors)", - false, false); -} \ No newline at end of file +// Free function callable from RegisterPasses.cpp for the new PM wrapper. +namespace llfi { +void runHardwareFailureAutoScan(llvm::Module &M) { + HardwareFailureAutoScanPass().runOnModule(M); +} +} // namespace llfi \ No newline at end of file diff --git a/llvm_passes/MainGraphInstSelector.cpp b/llvm_passes/MainGraphInstSelector.cpp index d148f647..6ecacb95 100644 --- a/llvm_passes/MainGraphInstSelector.cpp +++ b/llvm_passes/MainGraphInstSelector.cpp @@ -1,8 +1,8 @@ -#include "llvm/IR/Instructions.h" - #include "FICustomSelectorManager.h" #include "FIInstSelector.h" +#include "llvm/IR/Instructions.h" + using namespace llvm; namespace llfi { @@ -15,7 +15,7 @@ namespace llfi { // config file class MainGraphInstSelector : public HardwareFIInstSelector { private: - virtual bool isInstFITarget(Instruction *inst) { + bool isInstFITarget(Instruction *inst) override { if (inst->getParent()->getParent()->getName() == "main_graph") { if (inst->getOpcode() == Instruction::FAdd || inst->getOpcode() == Instruction::FMul || @@ -27,7 +27,7 @@ class MainGraphInstSelector : public HardwareFIInstSelector { } public: - virtual void getCompileTimeInfo(std::map &info) { + void getCompileTimeInfo(std::map &info) override { info["failure_class"] = "HardwareFault"; info["failure_mode"] = "MainGraph"; info["targets"] = ""; diff --git a/llvm_passes/RegisterPasses.cpp b/llvm_passes/RegisterPasses.cpp index c6f4620e..fef8bd3e 100644 --- a/llvm_passes/RegisterPasses.cpp +++ b/llvm_passes/RegisterPasses.cpp @@ -1,36 +1,47 @@ #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/PassPlugin.h" -#include "core/ProfilingPass.h" -#include "core/GenLLFIIndexPass.h" #include "core/FaultInjectionPass.h" -#include "core/LLFIDotGraphPass.h" +#include "core/GenLLFIIndexPass.h" #include "core/InstTracePass.h" +#include "core/LLFIDotGraphPass.h" +#include "core/ProfilingPass.h" using namespace llvm; namespace llfi { +// Forward declarations for auto-scan free functions defined in their +// respective .cpp translation units. +void runHardwareFailureAutoScan(llvm::Module &M); - //----------------------------------------------------------------------------- - // New PM Registration - //----------------------------------------------------------------------------- - llvm::PassPluginLibraryInfo getLLFIPassPluginInfo() { - return {LLVM_PLUGIN_API_VERSION, "llfi_passes", LLVM_VERSION_STRING, - [](PassBuilder &PB) { +struct NewHardwareFailureAutoScanPass + : PassInfoMixin { + PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { + runHardwareFailureAutoScan(M); + return PreservedAnalyses::all(); + } + static bool isRequired() { return true; } +}; - // For GenLLFIIndexPass - PB.registerPipelineParsingCallback( - [](StringRef Name, ModulePassManager &MPM, - ArrayRef) { - if (Name == "genllfiindexpass") { - MPM.addPass(llfi::GenLLFIIndexPass()); - return true; - } - return false; - }); +//----------------------------------------------------------------------------- +// New PM Registration +//----------------------------------------------------------------------------- +llvm::PassPluginLibraryInfo getLLFIPassPluginInfo() { + return {LLVM_PLUGIN_API_VERSION, "llfi_passes", LLVM_VERSION_STRING, + [](PassBuilder &PB) { + // For GenLLFIIndexPass + PB.registerPipelineParsingCallback( + [](StringRef Name, ModulePassManager &MPM, + ArrayRef) { + if (Name == "genllfiindexpass") { + MPM.addPass(llfi::GenLLFIIndexPass()); + return true; + } + return false; + }); - // For ProfilingPass - PB.registerPipelineParsingCallback( + // For ProfilingPass + PB.registerPipelineParsingCallback( [](StringRef Name, ModulePassManager &MPM, ArrayRef) { if (Name == "profilingpass") { @@ -40,8 +51,8 @@ namespace llfi { return false; }); - // For FaultInjectionPass - PB.registerPipelineParsingCallback( + // For FaultInjectionPass + PB.registerPipelineParsingCallback( [](StringRef Name, ModulePassManager &MPM, ArrayRef) { if (Name == "faultinjectionpass") { @@ -51,8 +62,8 @@ namespace llfi { return false; }); - // For DotGraphPass - PB.registerPipelineParsingCallback( + // For DotGraphPass + PB.registerPipelineParsingCallback( [](StringRef Name, ModulePassManager &MPM, ArrayRef) { if (Name == "dotgraphpass") { @@ -62,8 +73,8 @@ namespace llfi { return false; }); - // For InstructionTracePass - PB.registerPipelineParsingCallback( + // For InstructionTracePass + PB.registerPipelineParsingCallback( [](StringRef Name, ModulePassManager &MPM, ArrayRef) { if (Name == "insttracepass") { @@ -72,11 +83,22 @@ namespace llfi { } return false; }); - }}; - } - extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo - llvmGetPassPluginInfo() { - return getLLFIPassPluginInfo(); - } + // For HardwareFailureAutoScanPass + PB.registerPipelineParsingCallback( + [](StringRef Name, ModulePassManager &MPM, + ArrayRef) { + if (Name == "HardwareFailureAutoScanPass") { + MPM.addPass(llfi::NewHardwareFailureAutoScanPass()); + return true; + } + return false; + }); + }}; +} + +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo +llvmGetPassPluginInfo() { + return getLLFIPassPluginInfo(); } +} // namespace llfi diff --git a/llvm_passes/SampleFIInstSelector.cpp b/llvm_passes/SampleFIInstSelector.cpp index 457276c9..b5e82757 100644 --- a/llvm_passes/SampleFIInstSelector.cpp +++ b/llvm_passes/SampleFIInstSelector.cpp @@ -1,7 +1,7 @@ -#include "llvm/IR/Instructions.h" - -#include "FIInstSelector.h" #include "FICustomSelectorManager.h" +#include "FIInstSelector.h" + +#include "llvm/IR/Instructions.h" using namespace llvm; @@ -12,16 +12,17 @@ namespace llfi { */ // TODO: enable custom selctor to have more sources of options, e.g. read from // config file -class SampleFIInstSelector: public HardwareFIInstSelector { - private: - virtual bool isInstFITarget(Instruction *inst) { +class SampleFIInstSelector : public HardwareFIInstSelector { +private: + bool isInstFITarget(Instruction *inst) override { if (inst->getParent()->getParent()->getName() == "main") return true; else return false; } - public: - virtual void getCompileTimeInfo(std::map& info){ + +public: + void getCompileTimeInfo(std::map &info) override { info["failure_class"] = "HardwareFault"; info["failure_mode"] = "OnlyMain"; info["targets"] = ""; @@ -30,4 +31,4 @@ class SampleFIInstSelector: public HardwareFIInstSelector { }; static RegisterFIInstSelector X("onlymain", new SampleFIInstSelector()); -} +} // namespace llfi diff --git a/llvm_passes/SampleFIRegSelector.cpp b/llvm_passes/SampleFIRegSelector.cpp index 84265c01..5620e107 100644 --- a/llvm_passes/SampleFIRegSelector.cpp +++ b/llvm_passes/SampleFIRegSelector.cpp @@ -1,18 +1,18 @@ -#include "llvm/IR/Value.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Constants.h" - -#include "FIRegSelector.h" #include "FICustomSelectorManager.h" +#include "FIRegSelector.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Value.h" namespace llfi { /** * This sample register selector only selects constant int as target */ -class SampleFIRegSelector: public HardwareFIRegSelector { - private: - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst) { +class SampleFIRegSelector : public HardwareFIRegSelector { +private: + bool isRegofInstFITarget(Value *reg, Instruction *inst) override { if (isa(reg)) return true; else @@ -22,4 +22,4 @@ class SampleFIRegSelector: public HardwareFIRegSelector { static RegisterFIRegSelector X("onlyconstint", new SampleFIRegSelector()); -} +} // namespace llfi diff --git a/llvm_passes/SoftwareFailureAutoScanPass.cpp b/llvm_passes/SoftwareFailureAutoScanPass.cpp deleted file mode 100644 index f389fc19..00000000 --- a/llvm_passes/SoftwareFailureAutoScanPass.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#define DEBUG_TYPE "SoftwareFailureAutoScanPass" - -#include "FICustomSelectorManager.h" -#include "Utils.h" -#include "FIInstSelectorManager.h" -#include "FIInstSelector.h" -#include "InstTypeFIInstSelector.h" -#include "FuncNameFIInstSelector.h" -#include "FIRegSelector.h" -#include "RegLocBasedFIRegSelector.h" - -#include "llvm/Pass.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Support/CommandLine.h" - -#include -#include - -using namespace llvm; -namespace llfi{ - static cl::opt< std::string > outputpath("softwarescan_outputfilename", - cl::desc("The path to store a list of applicable software failures"), - cl::init("llfi.applicable.software.failures.txt")); - - class SoftwareFailureAutoScanPass: public ModulePass{ - private: - std::ofstream selector_record_file; - public: - static char ID; - SoftwareFailureAutoScanPass():ModulePass(ID){} - virtual bool runOnModule(Module &M){ - selector_record_file.open(std::string(outputpath).c_str(), std::ofstream::out); - selector_record_file<<"instSelMethod:"<<"\n"; - - FICustomInstSelectorManager *im = FICustomInstSelectorManager::getCustomInstSelectorManager(); - FICustomRegSelectorManager *rm = FICustomRegSelectorManager::getCustomRegSelectorManager(); - std::set all_software_failure_names; - im->getAllSoftwareSelectors(all_software_failure_names); - // errs()<<"get all soft failures\n"; - for(std::set::iterator name = all_software_failure_names.begin(); - name != all_software_failure_names.end(); name++){ - // errs()<<"# start on: "<<*name<<"\n"; - FIInstSelectorManager *fiinstselector = new FIInstSelectorManager; - fiinstselector->addSelector(im->getCustomInstSelector(*name)); - // errs()<<"# inst selector done on: "<<*name<<"\n"; - FIRegSelector* firegselector = rm->getCustomRegSelector(*name); - // errs()<<"# reg selector done on: "<<*name<<"\n"; - // select fault injection instructions - std::set fiinstset; - fiinstselector->getFIInsts(M, &fiinstset); - // errs()<<"# size of inst set: "<* > fi_inst_regs_map; - // select fault injection registers - firegselector->getFIInstRegMap(&fiinstset, &fi_inst_regs_map); - delete fiinstselector; - // errs()<<"# collection done on: "<<*name<<"\n"; - bool not_empty = false; - for(std::map* >::iterator MI = fi_inst_regs_map.begin(); - MI != fi_inst_regs_map.end(); MI++){ - if(MI->second->empty()) continue; - else not_empty = true; - } - if(not_empty == true){ - recordInstSelector(*name); - } - // errs()<<"# check done on: "<<*name<<"\n"; - for(std::map* >::iterator MI = fi_inst_regs_map.begin(); - MI != fi_inst_regs_map.end(); MI++){ - delete MI->second; - } - } - selector_record_file.close(); - } - - void recordInstSelector(std::string selector_name){ - if(selector_record_file.is_open() == false){ - std::cerr<<"ERROR: can not open file to record applicable selectors: "; - std::cerr< - X("SoftwareFailureAutoScanPass", "Automatic scanner of software failure modes", - false, false); -} \ No newline at end of file diff --git a/llvm_passes/core/Controller.cpp b/llvm_passes/core/Controller.cpp index b0fd23bd..dfd654a0 100644 --- a/llvm_passes/core/Controller.cpp +++ b/llvm_passes/core/Controller.cpp @@ -1,17 +1,19 @@ -#include "llvm/IR/Module.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/raw_ostream.h" - #include "Controller.h" + #include "FICustomSelectorManager.h" -#include "Utils.h" -#include "FIInstSelectorManager.h" #include "FIInstSelector.h" -#include "InstTypeFIInstSelector.h" -#include "FuncNameFIInstSelector.h" +#include "FIInstSelectorManager.h" #include "FIRegSelector.h" +#include "InstTypeFIInstSelector.h" #include "RegLocBasedFIRegSelector.h" +#include "Utils.h" + +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "FuncNameFIInstSelector.h" using namespace llvm; @@ -29,34 +31,39 @@ static cl::list fiinstselmethod( cl::ZeroOrMore); // inst type -static cl::list< std::string > includeinst("includeinst", - cl::desc("The type of instruction to be included for fault injection"), +static cl::list includeinst( + "includeinst", + cl::desc("The type of instruction to be included for fault injection"), cl::ZeroOrMore); -static cl::list< std::string > excludeinst("excludeinst", - cl::desc("The type of instruction to be excluded for fault injection"), +static cl::list excludeinst( + "excludeinst", + cl::desc("The type of instruction to be excluded for fault injection"), cl::ZeroOrMore); // func name -static cl::list< std::string > includefunc("includefunc", - cl::desc("The function name to be included for fault injection"), +static cl::list includefunc( + "includefunc", + cl::desc("The function name to be included for fault injection"), cl::ZeroOrMore); -static cl::list< std::string > excludefunc("excludefunc", - cl::desc("The function name to be excluded for fault injection"), +static cl::list excludefunc( + "excludefunc", + cl::desc("The function name to be excluded for fault injection"), cl::ZeroOrMore); // custom instruction selector name -static cl::opt < std::string > fiinstselectorname("fiinstselectorname", +static cl::opt fiinstselectorname( + "fiinstselectorname", cl::desc("Custom fault injection instruction selector name")); // backtrace or forwardtrace included -static cl::opt< bool > includebackwardtrace("includebackwardtrace", - cl::init(false), - cl::desc( - "Include backward trace of the selected instructions for fault injection")); -static cl::opt< bool > includeforwardtrace("includeforwardtrace", - cl::init(false), - cl::desc( - "Include forward trace of the selected instructions for fault injection")); +static cl::opt + includebackwardtrace("includebackwardtrace", cl::init(false), + cl::desc("Include backward trace of the selected " + "instructions for fault injection")); +static cl::opt + includeforwardtrace("includeforwardtrace", cl::init(false), + cl::desc("Include forward trace of the selected " + "instructions for fault injection")); /** * Inject Register @@ -73,26 +80,27 @@ static cl::opt firegselmethod( static cl::opt fireglocation( cl::desc("Choose fault injection register location:"), cl::init(dstreg), cl::values(clEnumVal(dstreg, "Inject into destination register"), - clEnumVal(allreg, "Inject randomly into either destination register or one of the source registers"), - clEnumVal(allsrcreg, "Inject randomly into one of the source registers"), + clEnumVal(allreg, "Inject randomly into either destination " + "register or one of the source registers"), + clEnumVal(allsrcreg, + "Inject randomly into one of the source registers"), clEnumVal(srcreg1, "Inject into 1st source register"), clEnumVal(srcreg2, "Inject into 2nd source register"), clEnumVal(srcreg3, "Inject into 3rd source register"), clEnumVal(srcreg4, "Inject into 4th source register"))); -static cl::opt < std::string > firegselectorname("firegselectorname", +static cl::opt firegselectorname( + "firegselectorname", cl::desc("Custom fault injection register selector name")); /** * Log file */ -cl::opt < std::string > llfilogfile("llfilogfile", - cl::init("llfi.log.compilation.txt"), - cl::Hidden, - cl::desc("Name of compilation passes logging file")); - +cl::opt + llfilogfile("llfilogfile", cl::init("llfi.log.compilation.txt"), cl::Hidden, + cl::desc("Name of compilation passes logging file")); -Controller *Controller::ctrl = NULL; +Controller *Controller::ctrl = nullptr; void Controller::getOpcodeListofFIInsts(std::set *fi_opcode_set) { NameOpcodeMap fullnameopcodemap; @@ -103,8 +111,8 @@ void Controller::getOpcodeListofFIInsts(std::set *fi_opcode_set) { // TODO: make "all" a static string if (includeinst[i] == "all") { for (NameOpcodeMap::const_iterator it = fullnameopcodemap.begin(); - it != fullnameopcodemap.end(); ++it) { - fi_opcode_set->insert(it->second); + it != fullnameopcodemap.end(); ++it) { + fi_opcode_set->insert(it->second); } break; } else { @@ -113,7 +121,7 @@ void Controller::getOpcodeListofFIInsts(std::set *fi_opcode_set) { fi_opcode_set->insert(loc->second); } else { errs() << "ERROR: Invalid include instruction type: " << includeinst[i] - << "\n"; + << "\n"; exit(1); } } @@ -126,7 +134,7 @@ void Controller::getOpcodeListofFIInsts(std::set *fi_opcode_set) { fi_opcode_set->erase(loc->second); } else { errs() << "ERROR: Invalid exclude instruction type: " << excludeinst[i] - << "\n"; + << "\n"; exit(1); } } @@ -135,8 +143,8 @@ void Controller::getFuncList(std::set *fi_func_set) { std::set::iterator it; std::string func; for (size_t i = 0; i < includefunc.size(); ++i) { - if(includefunc[i] == "all") { - for(it = func_set.begin(); it != func_set.end(); ++it) { + if (includefunc[i] == "all") { + for (it = func_set.begin(); it != func_set.end(); ++it) { func = demangleFuncName(*it); fi_func_set->insert(func); } @@ -147,13 +155,13 @@ void Controller::getFuncList(std::set *fi_func_set) { } // exclude list - for(size_t i = 0; i < excludefunc.size(); ++i) { + for (size_t i = 0; i < excludefunc.size(); ++i) { it = fi_func_set->find(excludefunc[i]); - if(it != fi_func_set->end()) { + if (it != fi_func_set->end()) { fi_func_set->erase(it); } else { errs() << "ERROR: Invalid exclude function name: " << excludefunc[i] - << "\n"; + << "\n"; exit(1); } } @@ -161,29 +169,29 @@ void Controller::getFuncList(std::set *fi_func_set) { void Controller::processInstSelArgs() { fiinstselector = new FIInstSelectorManager(); - std::set *fi_opcode_set; - std::set *fi_func_set; - FICustomInstSelectorManager *m; - for(size_t i = 0; i < fiinstselmethod.size(); ++i) { - switch(fiinstselmethod[i]) { - case insttype: - fi_opcode_set = new std::set; - getOpcodeListofFIInsts(fi_opcode_set); - fiinstselector->addSelector(new InstTypeFIInstSelector(fi_opcode_set)); - break; - case funcname: - fi_func_set = new std::set; - getFuncList(fi_func_set); - fiinstselector->addSelector(new FuncNameFIInstSelector(fi_func_set)); - break; - case custominstselector: - m = FICustomInstSelectorManager::getCustomInstSelectorManager(); - fiinstselector->addSelector(m->getCustomInstSelector(fiinstselectorname)); - break; - default: - // TODO: handle the source code case - errs() << "ERROR: option not implemented yet\n"; - exit(4); + std::set *fi_opcode_set = nullptr; + std::set *fi_func_set = nullptr; + FICustomInstSelectorManager *m = nullptr; + for (size_t i = 0; i < fiinstselmethod.size(); ++i) { + switch (fiinstselmethod[i]) { + case insttype: + fi_opcode_set = new std::set; + getOpcodeListofFIInsts(fi_opcode_set); + fiinstselector->addSelector(new InstTypeFIInstSelector(fi_opcode_set)); + break; + case funcname: + fi_func_set = new std::set; + getFuncList(fi_func_set); + fiinstselector->addSelector(new FuncNameFIInstSelector(fi_func_set)); + break; + case custominstselector: + m = FICustomInstSelectorManager::getCustomInstSelectorManager(); + fiinstselector->addSelector(m->getCustomInstSelector(fiinstselectorname)); + break; + default: + // TODO: handle the source code case + errs() << "ERROR: option not implemented yet\n"; + exit(4); } } fiinstselector->setIncludeBackwardTrace(includebackwardtrace); @@ -191,11 +199,11 @@ void Controller::processInstSelArgs() { } void Controller::processRegSelArgs() { - firegselector = NULL; + firegselector = nullptr; if (firegselmethod == regloc) { firegselector = new RegLocBasedFIRegSelector(fireglocation); } else { - FICustomRegSelectorManager *m = + FICustomRegSelectorManager *m = FICustomRegSelectorManager::getCustomRegSelectorManager(); firegselector = m->getCustomRegSelector(firegselectorname); } @@ -209,7 +217,7 @@ void Controller::processCmdArgs() { logFile << "\n\nStart of a pass\n"; } else { errs() << "Unable to output logging information to file " << llfilogfile - << "\n"; + << "\n"; } logFile.close(); @@ -221,7 +229,7 @@ void Controller::processCmdArgs() { // compiling C++ due to name mangling. void Controller::getModuleFuncs(Module &M) { Module::iterator it; - for(it = M.begin(); it != M.end(); ++it) { + for (it = M.begin(); it != M.end(); ++it) { std::string func_name = it->getName().str(); std::string final_name = demangleFuncName(func_name); @@ -235,7 +243,7 @@ void Controller::init(Module &M) { processCmdArgs(); // select fault injection instructions - std::set fiinstset; + std::set fiinstset; fiinstselector->getFIInsts(M, &fiinstset); // select fault injection registers @@ -243,26 +251,28 @@ void Controller::init(Module &M) { } Controller::~Controller() { - delete ctrl; - ctrl = NULL; + ctrl = nullptr; } void Controller::dump() const { - for (std::map *>::const_iterator inst_it = - fi_inst_regs_map.begin(); inst_it != fi_inst_regs_map.end(); ++inst_it) { + for (std::map *>::const_iterator inst_it = + fi_inst_regs_map.begin(); + inst_it != fi_inst_regs_map.end(); ++inst_it) { errs() << "Selected instruction " << *(inst_it->first) << "\nRegs:\n"; for (std::list::const_iterator reg_it = inst_it->second->begin(); reg_it != inst_it->second->end(); ++reg_it) { - if(*reg_it == DST_REG_POS) errs() << "\t" << *(inst_it->first) << "\n"; - else errs() << "\t" << inst_it->first->getOperand(*reg_it) << "\n"; + if (*reg_it == DST_REG_POS) + errs() << "\t" << *(inst_it->first) << "\n"; + else + errs() << "\t" << inst_it->first->getOperand(*reg_it) << "\n"; } errs() << "\n"; } } Controller *Controller::getInstance(Module &M) { - if (ctrl == NULL) + if (ctrl == nullptr) ctrl = new Controller(M); return ctrl; } -} +} // namespace llfi diff --git a/llvm_passes/core/Controller.h b/llvm_passes/core/Controller.h index ed7c6f28..c4f29519 100644 --- a/llvm_passes/core/Controller.h +++ b/llvm_passes/core/Controller.h @@ -1,14 +1,14 @@ -#ifndef CONFIG_H -#define CONFIG_H +#ifndef CONTROLLER_H +#define CONTROLLER_H #define LLVM_ON_UNIX 1 #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" -#include -#include #include +#include +#include #include #define DST_REG_POS -1 @@ -19,37 +19,28 @@ namespace llfi { class FIInstSelectorManager; class FIRegSelector; +enum FIInstSelMethod { insttype, funcname, sourcecode, custominstselector }; -enum FIInstSelMethod { - insttype, funcname, sourcecode, custominstselector -}; - -enum FIRegSelMethod { - regloc, customregselector -}; +enum FIRegSelMethod { regloc, customregselector }; -enum FIRegLoc { - dstreg, allreg, allsrcreg, srcreg1, srcreg2, srcreg3, srcreg4 -}; +enum FIRegLoc { dstreg, allreg, allsrcreg, srcreg1, srcreg2, srcreg3, srcreg4 }; class Controller { typedef std::map NameOpcodeMap; - public: + +public: static Controller *getInstance(Module &M); - ~Controller(); + ~Controller(); - public: - void getFIInstRegsMap( - std::map* > **fiinstreg) { +public: + void getFIInstRegsMap(std::map *> **fiinstreg) { *fiinstreg = &fi_inst_regs_map; } void dump() const; - private: +private: Controller() {} - Controller(Module &M) { - init(M); - } + Controller(Module &M) { init(M); } void init(Module &M); void processCmdArgs(); void processInstSelArgs(); @@ -59,24 +50,24 @@ class Controller { void getFuncList(std::set *fi_func_set); void getModuleFuncs(Module &M); - // output of the controller - private: + // output of the controller +private: // a map of target instructions and a list of inject loc // Assumption: changes on instructions do not have temporal relations // That's why we can use unordered map instead of list // TODO: replace tree-based map to hashtable-based map for performance - std::map* > fi_inst_regs_map; + std::map *> fi_inst_regs_map; - private: +private: FIInstSelectorManager *fiinstselector; FIRegSelector *firegselector; // set of functions present in module std::set func_set; - private: +private: static Controller *ctrl; }; -} +} // namespace llfi #endif diff --git a/llvm_passes/core/FICustomSelectorManager.cpp b/llvm_passes/core/FICustomSelectorManager.cpp index 9aa32ea5..71ba5d6b 100644 --- a/llvm_passes/core/FICustomSelectorManager.cpp +++ b/llvm_passes/core/FICustomSelectorManager.cpp @@ -1,16 +1,17 @@ -#include "llvm/Support/raw_ostream.h" +#include "FICustomSelectorManager.h" #include "FIInstSelector.h" #include "FIRegSelector.h" -#include "FICustomSelectorManager.h" + +#include "llvm/Support/raw_ostream.h" namespace llfi { // fault injection instruction selector manager -FICustomInstSelectorManager - *FICustomInstSelectorManager::getCustomInstSelectorManager() { +FICustomInstSelectorManager * +FICustomInstSelectorManager::getCustomInstSelectorManager() { static FICustomInstSelectorManager instsel_manager; - return &instsel_manager; + return &instsel_manager; } void FICustomInstSelectorManager::addCustomInstSelector( @@ -19,50 +20,37 @@ void FICustomInstSelectorManager::addCustomInstSelector( optionname_instselector[name] = instselector; } else { errs() << "ERROR: Duplicate custom fault injection instruction selector: " - << name << "\n"; + << name << "\n"; exit(1); } } -void FICustomInstSelectorManager::getAllSoftwareSelectors( - std::set& all_software_failure_names){ - for(std::map::iterator it = - optionname_instselector.begin(); it != optionname_instselector.end(); - ++it) { - if(it->second->getInstSelectorClass() == std::string("SoftwareFault")){ - all_software_failure_names.insert(it->first); - } - } - return; -} - void FICustomInstSelectorManager::getAllHardwareSelectors( - std::set& all_hardware_failure_names){ - for(std::map::iterator it = - optionname_instselector.begin(); it != optionname_instselector.end(); - ++it) { - if(it->second->getInstSelectorClass() == std::string("HardwareFault")){ - all_hardware_failure_names.insert(it->first); - } + std::set &all_hardware_failure_names) { + for (std::map::iterator it = + optionname_instselector.begin(); + it != optionname_instselector.end(); ++it) { + if (it->second->getInstSelectorClass() == std::string("HardwareFault")) { + all_hardware_failure_names.insert(it->first); + } } return; } -FIInstSelector *FICustomInstSelectorManager::getCustomInstSelector( - const std::string &name) { +FIInstSelector * +FICustomInstSelectorManager::getCustomInstSelector(const std::string &name) { if (optionname_instselector.find(name) != optionname_instselector.end()) { return optionname_instselector[name]; } else { errs() << "ERROR: Unknown custom fault injection instruction selector: " - << name << "\n"; + << name << "\n"; exit(1); } } - // fault injection register selector manager -FICustomRegSelectorManager - *FICustomRegSelectorManager::getCustomRegSelectorManager() { +FICustomRegSelectorManager * +FICustomRegSelectorManager::getCustomRegSelectorManager() { static FICustomRegSelectorManager regsel_manager; return ®sel_manager; } @@ -73,44 +61,32 @@ void FICustomRegSelectorManager::addCustomRegSelector( optionname_regselector[name] = regselector; } else { errs() << "ERROR: Duplicate custom fault injection register selector: " - << name << "\n"; + << name << "\n"; exit(1); } } -FIRegSelector *FICustomRegSelectorManager::getCustomRegSelector( - const std::string &name) { +FIRegSelector * +FICustomRegSelectorManager::getCustomRegSelector(const std::string &name) { if (optionname_regselector.find(name) != optionname_regselector.end()) { return optionname_regselector[name]; } else { - errs() << "ERROR: Unknown custom fault injection register selector: " << - name << "\n"; + errs() << "ERROR: Unknown custom fault injection register selector: " + << name << "\n"; exit(1); } } -void FICustomRegSelectorManager::getAllSoftwareSelectors( - std::set& all_software_failure_names){ - for(std::map::iterator it = - optionname_regselector.begin(); it != optionname_regselector.end(); - ++it) { - if(it->second->getRegSelectorClass() == std::string("SoftwareFault")){ - all_software_failure_names.insert(it->first); - } - } - return; -} - void FICustomRegSelectorManager::getAllHardwareSelectors( - std::set& all_hardware_failure_names){ - for(std::map::iterator it = - optionname_regselector.begin(); it != optionname_regselector.end(); - ++it) { - if(it->second->getRegSelectorClass() == std::string("HardwareFault")){ - all_hardware_failure_names.insert(it->first); - } + std::set &all_hardware_failure_names) { + for (std::map::iterator it = + optionname_regselector.begin(); + it != optionname_regselector.end(); ++it) { + if (it->second->getRegSelectorClass() == std::string("HardwareFault")) { + all_hardware_failure_names.insert(it->first); + } } return; } -} +} // namespace llfi diff --git a/llvm_passes/core/FICustomSelectorManager.h b/llvm_passes/core/FICustomSelectorManager.h index 6a95c669..6e4a1cd9 100644 --- a/llvm_passes/core/FICustomSelectorManager.h +++ b/llvm_passes/core/FICustomSelectorManager.h @@ -3,51 +3,49 @@ #include #include -#include #include +#include namespace llfi { class FIInstSelector; class FIRegSelector; class FICustomInstSelectorManager { - public: +public: FICustomInstSelectorManager() {} - public: +public: static FICustomInstSelectorManager *getCustomInstSelectorManager(); - void addCustomInstSelector(const std::string &name, + void addCustomInstSelector(const std::string &name, FIInstSelector *instselector); FIInstSelector *getCustomInstSelector(const std::string &name); - void getAllSoftwareSelectors(std::set& all_software_failure_names); - void getAllHardwareSelectors(std::set& all_hardware_failure_names); + void + getAllHardwareSelectors(std::set &all_hardware_failure_names); - private: - std::map optionname_instselector; +private: + std::map optionname_instselector; }; - class FICustomRegSelectorManager { - public: +public: FICustomRegSelectorManager() {} - public: +public: static FICustomRegSelectorManager *getCustomRegSelectorManager(); - void addCustomRegSelector(const std::string &name, + void addCustomRegSelector(const std::string &name, FIRegSelector *regselector); FIRegSelector *getCustomRegSelector(const std::string &name); - void getAllSoftwareSelectors(std::set& all_software_failure_names); - void getAllHardwareSelectors(std::set& all_hardware_failure_names); - - private: - std::map optionname_regselector; -}; + void + getAllHardwareSelectors(std::set &all_hardware_failure_names); +private: + std::map optionname_regselector; +}; // helper class to register custom inst or reg selector struct RegisterFIInstSelector { RegisterFIInstSelector(const std::string &name, FIInstSelector *sel) { - FICustomInstSelectorManager *m = + FICustomInstSelectorManager *m = FICustomInstSelectorManager::getCustomInstSelectorManager(); m->addCustomInstSelector(name, sel); } @@ -61,5 +59,5 @@ struct RegisterFIRegSelector { } }; -} +} // namespace llfi #endif diff --git a/llvm_passes/core/FIInstSelector.cpp b/llvm_passes/core/FIInstSelector.cpp index 4fc0ec12..5fa72bc6 100644 --- a/llvm_passes/core/FIInstSelector.cpp +++ b/llvm_passes/core/FIInstSelector.cpp @@ -1,14 +1,14 @@ +#include "FIInstSelector.h" + #include "llvm/IR/InstIterator.h" #include "llvm/Support/raw_ostream.h" -#include "FIInstSelector.h" - namespace llfi { -void FIInstSelector::getFIInsts(Module &M, std::set *fiinsts) { +void FIInstSelector::getFIInsts(Module &M, std::set *fiinsts) { getInitFIInsts(M, fiinsts); - std::set bs; - std::set fs; + std::set bs; + std::set fs; // must do both of the computation on the fiinsts, and update // fiinsts finally if (includebackwardtrace) @@ -20,8 +20,8 @@ void FIInstSelector::getFIInsts(Module &M, std::set *fiinsts) { fiinsts->insert(fs.begin(), fs.end()); } -void FIInstSelector::getInitFIInsts(Module &M, - std::set *fiinsts) { +void FIInstSelector::getInitFIInsts(Module &M, + std::set *fiinsts) { for (Module::iterator m_it = M.begin(); m_it != M.end(); ++m_it) { if (!m_it->isDeclaration()) { // m_it is a function @@ -32,13 +32,13 @@ void FIInstSelector::getInitFIInsts(Module &M, fiinsts->insert(inst); } } - } + } } } void FIInstSelector::getBackwardTraceofInsts( - const std::set *fiinsts, std::set *bs) { - for (std::set::const_iterator inst_it = fiinsts->begin(); + const std::set *fiinsts, std::set *bs) { + for (std::set::const_iterator inst_it = fiinsts->begin(); inst_it != fiinsts->end(); ++inst_it) { Instruction *inst = *inst_it; getBackwardTraceofInst(inst, bs); @@ -46,8 +46,8 @@ void FIInstSelector::getBackwardTraceofInsts( } void FIInstSelector::getForwardTraceofInsts( - const std::set *fiinsts, std::set *fs) { - for (std::set::const_iterator inst_it = fiinsts->begin(); + const std::set *fiinsts, std::set *fs) { + for (std::set::const_iterator inst_it = fiinsts->begin(); inst_it != fiinsts->end(); ++inst_it) { Instruction *inst = *inst_it; getForwardTraceofInst(inst, fs); @@ -55,9 +55,9 @@ void FIInstSelector::getForwardTraceofInsts( } void FIInstSelector::getBackwardTraceofInst(Instruction *inst, - std::set *bs) { - for (User::op_iterator op_it = inst->op_begin(); - op_it != inst->op_end(); ++op_it) { + std::set *bs) { + for (User::op_iterator op_it = inst->op_begin(); op_it != inst->op_end(); + ++op_it) { Value *src = *op_it; if (Instruction *src_inst = dyn_cast(src)) { if (bs->find(src_inst) == bs->end()) { @@ -69,7 +69,7 @@ void FIInstSelector::getBackwardTraceofInst(Instruction *inst, } void FIInstSelector::getForwardTraceofInst(Instruction *inst, - std::set *fs) { + std::set *fs) { for (Value::user_iterator user_it = inst->user_begin(); user_it != inst->user_end(); ++user_it) { User *user = *user_it; @@ -82,11 +82,12 @@ void FIInstSelector::getForwardTraceofInst(Instruction *inst, } } -void FIInstSelector::getCompileTimeInfo(std::map& info) { +void FIInstSelector::getCompileTimeInfo( + std::map &info) { info["failure_class"] = "Unknown"; info["failure_mode"] = "Unknown"; info["targets"] = "Unknown"; info["injector"] = "Unknown"; } -} +} // namespace llfi diff --git a/llvm_passes/core/FIInstSelector.h b/llvm_passes/core/FIInstSelector.h index e57cfcbe..0bfc7865 100644 --- a/llvm_passes/core/FIInstSelector.h +++ b/llvm_passes/core/FIInstSelector.h @@ -1,67 +1,59 @@ #ifndef FI_INST_SELECTOR_H #define FI_INST_SELECTOR_H -#include "llvm/IR/Module.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" -#include #include +#include using namespace llvm; namespace llfi { class FIInstSelector { - public: - FIInstSelector(): includebackwardtrace(false), includeforwardtrace(false) {} +public: + FIInstSelector() : includebackwardtrace(false), includeforwardtrace(false) {} + virtual ~FIInstSelector() = default; - public: - void getFIInsts(Module &M, std::set *fiinsts); - virtual void getCompileTimeInfo(std::map& info); +public: + void getFIInsts(Module &M, std::set *fiinsts); + virtual void getCompileTimeInfo(std::map &info); - virtual std::string getInstSelectorClass(){ - return std::string("Unknown"); - } + virtual std::string getInstSelectorClass() { return std::string("Unknown"); } - public: +public: inline void setIncludeBackwardTrace(bool includebt) { includebackwardtrace = includebt; } inline void setIncludeForwardTrace(bool includeft) { includeforwardtrace = includeft; } - - private: + +private: // get the initial fault injection instruction without backtrace or forward // trace, selection from source code may need to rewrite this function - virtual void getInitFIInsts(Module &M, std::set *fiinsts); + virtual void getInitFIInsts(Module &M, std::set *fiinsts); - virtual bool isInstFITarget(Instruction* inst) = 0; + virtual bool isInstFITarget(Instruction *inst) = 0; - protected: +protected: // only get the "instructions" that are the backward/forward trace of inst - void getBackwardTraceofInsts(const std::set *fiinsts, - std::set *bs); - void getForwardTraceofInsts(const std::set *fiinsts, - std::set *fs); - void getBackwardTraceofInst(Instruction *inst, - std::set *bs); - void getForwardTraceofInst(Instruction *inst, - std::set *fs); - protected: + void getBackwardTraceofInsts(const std::set *fiinsts, + std::set *bs); + void getForwardTraceofInsts(const std::set *fiinsts, + std::set *fs); + void getBackwardTraceofInst(Instruction *inst, std::set *bs); + void getForwardTraceofInst(Instruction *inst, std::set *fs); + +protected: bool includebackwardtrace; bool includeforwardtrace; -}; - -class SoftwareFIInstSelector: public FIInstSelector{ - virtual std::string getInstSelectorClass(){ - return std::string("SoftwareFault"); - } }; -class HardwareFIInstSelector: public FIInstSelector{ - virtual std::string getInstSelectorClass(){ +class HardwareFIInstSelector : public FIInstSelector { + std::string getInstSelectorClass() override { return std::string("HardwareFault"); } }; -} +} // namespace llfi #endif diff --git a/llvm_passes/core/FIInstSelectorManager.cpp b/llvm_passes/core/FIInstSelectorManager.cpp index 0488ba74..74fb4b4f 100644 --- a/llvm_passes/core/FIInstSelectorManager.cpp +++ b/llvm_passes/core/FIInstSelectorManager.cpp @@ -1,84 +1,78 @@ -#include "llvm/IR/Instructions.h" - #include "FIInstSelectorManager.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/raw_ostream.h" + namespace llfi { void FIInstSelectorManager::getFIInsts(Module &M, - std::set *fiinsts) { + std::set *fiinsts) { // Create a set for each selector and print compiletime info - std::vector*> allInsts; - for(it = selectors.begin(); it != selectors.end(); ++it) { - allInsts.push_back(new std::set); - std::map info; - (*it)->getCompileTimeInfo(info); - printCompileTimeInfo(info); - (*it)->getFIInsts(M, allInsts.back()); + std::vector *> allInsts; + for (it = selectors.begin(); it != selectors.end(); ++it) { + allInsts.push_back(new std::set); + std::map info; + (*it)->getCompileTimeInfo(info); + printCompileTimeInfo(info); + (*it)->getFIInsts(M, allInsts.back()); } // Merge allInsts into fiinsts - std::set merge = *(allInsts[0]); - for(size_t i = 1; i < allInsts.size(); ++i) { - std::set tmp; + std::set merge = *(allInsts[0]); + for (size_t i = 1; i < allInsts.size(); ++i) { + std::set tmp; tmp.swap(merge); - std::set_intersection(tmp.begin(), tmp.end(), - allInsts[i]->begin(), allInsts[i]->end(), + std::set_intersection(tmp.begin(), tmp.end(), allInsts[i]->begin(), + allInsts[i]->end(), std::inserter(merge, merge.begin())); } fiinsts->swap(merge); - for(size_t i = 0; i < allInsts.size(); ++i) { + for (size_t i = 0; i < allInsts.size(); ++i) { delete allInsts[i]; } } -int FIInstSelectorManager::printCompileTimeInfo(std::map& info) { +int FIInstSelectorManager::printCompileTimeInfo( + std::map &info) { // print compiletime info returned from inst selector, called by getFIInsts() std::ofstream compiletimeinfo_file("llfi.config.compiletime.txt"); - if(compiletimeinfo_file.is_open() == false){ - std::cerr<<"ERROR: can not open llfi.config.compiletime.txt\n"; + if (compiletimeinfo_file.is_open() == false) { + errs() << "ERROR: can not open llfi.config.compiletime.txt\n"; compiletimeinfo_file.close(); return -1; } - compiletimeinfo_file<<"failure_class="<setIncludeBackwardTrace(includebt); } - } -void FIInstSelectorManager::setIncludeForwardTrace(bool includeft) -{ - for(size_t i = 0; i < selectors.size(); ++i) { +void FIInstSelectorManager::setIncludeForwardTrace(bool includeft) { + for (size_t i = 0; i < selectors.size(); ++i) { selectors[i]->setIncludeForwardTrace(includeft); } } -FIInstSelectorManager::FIInstSelectorManager() -{ - -} +FIInstSelectorManager::FIInstSelectorManager() {} -FIInstSelectorManager::~FIInstSelectorManager() -{ - for(it = selectors.begin(); it != selectors.end(); ++it) { +FIInstSelectorManager::~FIInstSelectorManager() { + for (it = selectors.begin(); it != selectors.end(); ++it) { delete *it; } } -} +} // namespace llfi diff --git a/llvm_passes/core/FIInstSelectorManager.h b/llvm_passes/core/FIInstSelectorManager.h index 6c438d97..57640a62 100644 --- a/llvm_passes/core/FIInstSelectorManager.h +++ b/llvm_passes/core/FIInstSelectorManager.h @@ -1,32 +1,32 @@ #ifndef FI_INST_SELECTOR_MANAGER_H #define FI_INST_SELECTOR_MANAGER_H -#include -#include +#include "FIInstSelector.h" + #include #include - -#include "FIInstSelector.h" +#include +#include using namespace llvm; namespace llfi { class FIInstSelectorManager { - public: +public: FIInstSelectorManager(); ~FIInstSelectorManager(); void addSelector(FIInstSelector *s); - void getFIInsts(Module &M, std::set *fiinsts); + void getFIInsts(Module &M, std::set *fiinsts); void setIncludeBackwardTrace(bool includebt); void setIncludeForwardTrace(bool includeft); - private: - std::vector selectors; - std::vector::iterator it; +private: + std::vector selectors; + std::vector::iterator it; - int printCompileTimeInfo(std::map& info); + int printCompileTimeInfo(std::map &info); }; -} +} // namespace llfi #endif diff --git a/llvm_passes/core/FIRegSelector.cpp b/llvm_passes/core/FIRegSelector.cpp index bf96be5a..6a6313ce 100644 --- a/llvm_passes/core/FIRegSelector.cpp +++ b/llvm_passes/core/FIRegSelector.cpp @@ -1,36 +1,36 @@ +#include "FIRegSelector.h" + #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "FIRegSelector.h" - using namespace llvm; namespace llfi { -extern cl::opt< std::string > llfilogfile; +extern cl::opt llfilogfile; void FIRegSelector::getFIInstRegMap( - const std::set< Instruction* > *instset, - std::map* > *instregmap) { + const std::set *instset, + std::map *> *instregmap) { std::error_code err; raw_fd_ostream logFile(llfilogfile.c_str(), err, sys::fs::OF_Append); - for (std::set::const_iterator inst_it = instset->begin(); + for (std::set::const_iterator inst_it = instset->begin(); inst_it != instset->end(); ++inst_it) { Instruction *inst = *inst_it; std::list *reglist = new std::list(); - + // destination register if (isRegofInstFITarget(inst, inst)) { if (isRegofInstInjectable(inst, inst)) { - // dbgs() << "dstreg " << " inst: "<< *inst<<"\n"; + // dbgs() << "dstreg " << " inst: "<< *inst<<"\n"; reglist->push_back(DST_REG_POS); } else if (!err) { logFile << "LLFI cannot inject faults in destination reg of " << *inst - << "\n"; + << "\n"; } } // source register @@ -41,10 +41,11 @@ void FIRegSelector::getFIInstRegMap( if (isRegofInstFITarget(src, inst, pos)) { if (isRegofInstInjectable(src, inst)) { reglist->push_back(pos); - // dbgs()<<"srcreg "<<" inst:"<<*inst<<" reg:"<<*inst->getOperand(pos)<<" pos:"<getOperand(pos)<<" pos:"<(src)) + if (isa(src)) logFile << src->getName(); else logFile << *src; @@ -52,15 +53,16 @@ void FIRegSelector::getFIInstRegMap( } } } - + // Insert an instruction for FI only if the regList is non-empty - if (reglist->size() != 0) { - // dbgs() << "Inserting FI function for instruction " << *inst << " " << reglist->size() << "\n"; - instregmap->insert( - std::pair* >(inst, reglist)); + if (!reglist->empty()) { + // dbgs() << "Inserting FI function for instruction " << *inst << " " << + // reglist->size() << "\n"; + instregmap->insert( + std::pair *>(inst, reglist)); } else if (!err) { - logFile << "The selected instruction " << *inst << - "does not have any valid registers for fault injection\n"; + logFile << "The selected instruction " << *inst + << "does not have any valid registers for fault injection\n"; } } logFile.close(); @@ -81,8 +83,9 @@ bool FIRegSelector::isRegofInstInjectable(Value *reg, Instruction *inst) { return true; } -bool FIRegSelector::isRegofInstFITarget(Value* reg, Instruction* inst, int pos){ +bool FIRegSelector::isRegofInstFITarget(Value *reg, Instruction *inst, + int pos) { return isRegofInstFITarget(reg, inst); } -} +} // namespace llfi diff --git a/llvm_passes/core/FIRegSelector.h b/llvm_passes/core/FIRegSelector.h index f701b366..24662b40 100644 --- a/llvm_passes/core/FIRegSelector.h +++ b/llvm_passes/core/FIRegSelector.h @@ -1,43 +1,36 @@ #ifndef FI_REG_SELECTOR_H #define FI_REG_SELECTOR_H +#include "Controller.h" + #include "llvm/IR/Instruction.h" #include "llvm/IR/Value.h" -#include "Controller.h" -#include -#include #include +#include +#include #include using namespace llvm; namespace llfi { class FIRegSelector { - public: - void getFIInstRegMap(const std::set< Instruction* > *instset, - std::map* > *instregmap); - virtual std::string getRegSelectorClass(){ - return std::string("Unknown"); - } +public: + void getFIInstRegMap(const std::set *instset, + std::map *> *instregmap); + virtual std::string getRegSelectorClass() { return std::string("Unknown"); } - private: +private: virtual bool isRegofInstFITarget(Value *reg, Instruction *inst) = 0; - virtual bool isRegofInstFITarget(Value* reg, Instruction* inst, int pos); + virtual bool isRegofInstFITarget(Value *reg, Instruction *inst, int pos); // determine whether LLFI is able to inject into the specified reg or not bool isRegofInstInjectable(Value *reg, Instruction *inst); }; -class SoftwareFIRegSelector: public FIRegSelector { - virtual std::string getRegSelectorClass(){ - return std::string("SoftwareFault"); - } -}; - -class HardwareFIRegSelector: public FIRegSelector { - virtual std::string getRegSelectorClass(){ - return std::string("HardwareFault"); - } +class HardwareFIRegSelector : public FIRegSelector { + std::string getRegSelectorClass() override { + return std::string("HardwareFault"); + } }; -} +} // namespace llfi #endif diff --git a/llvm_passes/core/FaultInjectionPass.cpp b/llvm_passes/core/FaultInjectionPass.cpp index b2640174..7bedf68a 100644 --- a/llvm_passes/core/FaultInjectionPass.cpp +++ b/llvm_passes/core/FaultInjectionPass.cpp @@ -11,26 +11,27 @@ // // The fault injection function is a C function which performs the fault // injection at runtime -// See faultinjection_lib.c injectFunc() function for more details on the -// fault injection function. This function definition is linked to the -// instrumented bitcode file (after this pass). +// See faultinjection_lib.c injectFunc() function for more details on the +// fault injection function. This function definition is linked to the +// instrumented bitcode file (after this pass). //===----------------------------------------------------------------------===// +#include "FaultInjectionPass.h" + +#include "Controller.h" +#include "Utils.h" + +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/DataLayout.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include -#include "FaultInjectionPass.h" -#include "Controller.h" -#include "Utils.h" - namespace llfi { -char FaultInjectionPass::ID=0; +char FaultInjectionPass::ID = 0; std::string FaultInjectionPass::getFIFuncNameforType(const Type *type) { std::string funcname; @@ -38,7 +39,7 @@ std::string FaultInjectionPass::getFIFuncNameforType(const Type *type) { funcname = fi_rettype_funcname_map[type]; } else { funcname = "injectFault"; - int ficount = fi_rettype_funcname_map.size(); + int ficount = (int)fi_rettype_funcname_map.size(); funcname += intToString(ficount); fi_rettype_funcname_map[type] = funcname; } @@ -46,60 +47,68 @@ std::string FaultInjectionPass::getFIFuncNameforType(const Type *type) { } void FaultInjectionPass::insertInjectionFuncCall( - std::map* > *inst_regs_map, Module &M) { + std::map *> *inst_regs_map, Module &M) { - for (std::map* >::iterator inst_reg_it = - inst_regs_map->begin(); inst_reg_it != inst_regs_map->end(); - ++inst_reg_it) { + for (std::map *>::iterator inst_reg_it = + inst_regs_map->begin(); + inst_reg_it != inst_regs_map->end(); ++inst_reg_it) { Instruction *fi_inst = inst_reg_it->first; - + std::list *fi_reg_pos_list = inst_reg_it->second; - /*BEHROOZ: This section makes sure that we do not instrument the intrinsic functions*/ - if(isa(fi_inst)){ - bool continue_flag=false; + // Skip intrinsic functions to avoid invalid instrumentation + if (isa(fi_inst)) { + bool continue_flag = false; for (std::list::iterator reg_pos_it_mem = fi_reg_pos_list->begin(); - (reg_pos_it_mem != fi_reg_pos_list->end()) && (*reg_pos_it_mem != DST_REG_POS); ++reg_pos_it_mem) { + (reg_pos_it_mem != fi_reg_pos_list->end()) && + (*reg_pos_it_mem != DST_REG_POS); + ++reg_pos_it_mem) { std::string reg_mem = fi_inst->getOperand(*reg_pos_it_mem)->getName().str(); - if ((reg_mem.find("memcpy") != std::string::npos) || (reg_mem.find("memset") != std::string::npos) || (reg_mem.find("expect") != std::string::npos) || (reg_mem.find("memmove") != std::string::npos)){ - continue_flag=true; + if ((reg_mem.find("memcpy") != std::string::npos) || + (reg_mem.find("memset") != std::string::npos) || + (reg_mem.find("expect") != std::string::npos) || + (reg_mem.find("memmove") != std::string::npos)) { + continue_flag = true; break; } } - if(continue_flag) + if (continue_flag) continue; } - /*BEHROOZ: This is to make sure we do not instrument landingpad instructions.*/ + // Skip landingpad instructions which cannot be instrumented std::string current_opcode = fi_inst->getOpcodeName(); - if(current_opcode.find("landingpad") != std::string::npos){ + if (current_opcode.find("landingpad") != std::string::npos) { continue; } unsigned reg_index = 0; unsigned total_reg_num = fi_reg_pos_list->size(); - for (std::list::iterator reg_pos_it = fi_reg_pos_list->begin(); + for (std::list::iterator reg_pos_it = fi_reg_pos_list->begin(); reg_pos_it != fi_reg_pos_list->end(); ++reg_pos_it, ++reg_index) { - if(isa(fi_inst)){ - GetElementPtrInst* gepi = dyn_cast(fi_inst); + if (isa(fi_inst)) { + GetElementPtrInst *gepi = dyn_cast(fi_inst); gepi->setIsInBounds(false); } - if(isa(fi_inst)){ - CallInst* ci = dyn_cast(fi_inst); + if (isa(fi_inst)) { + CallInst *ci = dyn_cast(fi_inst); ci->setTailCall(false); } - Value* fi_reg = NULL; - if(*reg_pos_it == DST_REG_POS) fi_reg = fi_inst; - else fi_reg = fi_inst->getOperand(*reg_pos_it); - //if(isa(fi_reg)) continue; + Value *fi_reg = nullptr; + if (*reg_pos_it == DST_REG_POS) + fi_reg = fi_inst; + else + fi_reg = fi_inst->getOperand(*reg_pos_it); + // if(isa(fi_reg)) continue; Type *returntype = fi_reg->getType(); - + // The return type is not a valid return type, and should hence be ignored - // This is to deal with types such as Metadata that are not valid return types - // Or if the instruction is an intrinsic one (starts with 'llvm_', ignore it - if ( !FunctionType::isValidReturnType(returntype) || - isa(fi_inst) ) - continue; + // This is to deal with types such as Metadata that are not valid return + // types Or if the instruction is an intrinsic one (starts with 'llvm_', + // ignore it + if (!FunctionType::isValidReturnType(returntype) || + isa(fi_inst)) + continue; // Get the context for the function LLVMContext &context = M.getContext(); @@ -107,23 +116,20 @@ void FaultInjectionPass::insertInjectionFuncCall( Type *i32type = Type::getInt32Ty(context); // function declaration - std::vector paramtypes(7); - paramtypes[0] = i64type; // llfi index - paramtypes[1] = returntype; // the instruction to be injected - paramtypes[2] = i32type; // opcode - paramtypes[3] = i32type; // current fi reg index - paramtypes[4] = i32type; // total fi reg number - //======== Add reg_pos QINING @DEC 23rd ========== + std::vector paramtypes(7); + paramtypes[0] = i64type; // llfi index + paramtypes[1] = returntype; // the instruction to be injected + paramtypes[2] = i32type; // opcode + paramtypes[3] = i32type; // current fi reg index + paramtypes[4] = i32type; // total fi reg number paramtypes[5] = i32type; - //================================================ - //======== Add opcode_str QINING @MAR 11th======== paramtypes[6] = PointerType::get(Type::getInt8Ty(context), 0); - //================================================ - //LLVM 3.3 Upgrade - ArrayRef paramtypes_array_ref(paramtypes); + // LLVM 3.3 Upgrade + ArrayRef paramtypes_array_ref(paramtypes); // dbgs() << "Getting function of type : " << *returntype <<"\n"; - FunctionType* injectfunctype = FunctionType::get(returntype, paramtypes_array_ref, false); + FunctionType *injectfunctype = + FunctionType::get(returntype, paramtypes_array_ref, false); std::string funcname = getFIFuncNameforType(returntype); FunctionCallee injectfunc = @@ -134,41 +140,41 @@ void FaultInjectionPass::insertInjectionFuncCall( // injection into "the instruction", use instruction's index instead Value *indexval = ConstantInt::get(i64type, getLLFIIndexofInst(fi_inst)); - std::vector args(7); - args[0] = indexval; //llfi index - args[1] = fi_reg; // target register - args[2] = ConstantInt::get(i32type, fi_inst->getOpcode()); // opcode in i32 + std::vector args(7); + args[0] = indexval; // llfi index + args[1] = fi_reg; // target register + args[2] = + ConstantInt::get(i32type, fi_inst->getOpcode()); // opcode in i32 args[3] = ConstantInt::get(i32type, reg_index); // reg_index not reg_pos args[4] = ConstantInt::get(i32type, total_reg_num); // total_reg_num - //======== Add reg_pos QINING @DEC 23rd ========== - args[5] = ConstantInt::get(i32type, *reg_pos_it+1); // dstreg->0, operand0->1, operand1->2 ... - //================================================ - //======== Add opcode_str QINING @MAR 11th======== + args[5] = ConstantInt::get( + i32type, *reg_pos_it + 1); // dstreg->0, operand0->1, operand1->2 ... std::string opcode_str = fi_inst->getOpcodeName(); - GlobalVariable* opcode_str_gv = findOrCreateGlobalNameString(M, opcode_str); - std::vector indices_for_gep(2); - indices_for_gep[0] = ConstantInt::get(Type::getInt32Ty(context),0); - indices_for_gep[1] = ConstantInt::get(Type::getInt32Ty(context),0); - ArrayRef indices_for_gep_array_ref(indices_for_gep); + GlobalVariable *opcode_str_gv = + findOrCreateGlobalNameString(M, opcode_str); + std::vector indices_for_gep(2); + indices_for_gep[0] = ConstantInt::get(Type::getInt32Ty(context), 0); + indices_for_gep[1] = ConstantInt::get(Type::getInt32Ty(context), 0); + ArrayRef indices_for_gep_array_ref(indices_for_gep); Constant *opc_str = dyn_cast(opcode_str_gv); Type *ty = opcode_str_gv->getValueType(); Constant *gep_expr = ConstantExpr::getGetElementPtr( ty, opc_str, indices_for_gep_array_ref, true); args[6] = gep_expr; // opcode in string - //================================================ // LLVM 3.3 Upgrade - ArrayRef args_array_ref(args); + ArrayRef args_array_ref(args); Instruction *insertptr = getInsertPtrforRegsofInst(fi_reg, fi_inst); - Instruction *ficall = - CallInst::Create(injectfunc, args_array_ref, "fi", insertptr); - setInjectFaultInst(fi_reg, fi_inst, ficall); // sets the instruction metadata + Instruction *ficall = CallInst::Create(injectfunc, args_array_ref, "fi", + insertptr->getIterator()); + setInjectFaultInst(fi_reg, fi_inst, + ficall); // sets the instruction metadata // redirect the data dependencies if (fi_reg == fi_inst) { // inject into destination - std::list inst_uses; + std::list inst_uses; for (Value::user_iterator user_it = fi_inst->user_begin(); user_it != fi_inst->user_end(); ++user_it) { User *user = *user_it; @@ -182,12 +188,12 @@ void FaultInjectionPass::insertInjectionFuncCall( user_it != inst_uses.end(); ++user_it) { User *user = *user_it; user->replaceUsesOfWith(fi_inst, ficall); - + // update the selected inst pool /*if (Instruction *use_inst = dyn_cast(user)) { if (inst_regs_map->find(use_inst) != inst_regs_map->end()) { std::list *reg_pos_list = (*inst_regs_map)[use_inst]; - + for (std::list::iterator reg_pos_it = reg_pos_list->begin(); reg_pos_it != reg_pos_list->end(); ++reg_pos_it) { if (use_inst->getOperand(*reg_pos_it) == fi_inst) { @@ -200,7 +206,7 @@ void FaultInjectionPass::insertInjectionFuncCall( } else { // inject into source - //fi_inst->replaceUsesOfWith(fi_reg, ficall); + // fi_inst->replaceUsesOfWith(fi_reg, ficall); fi_inst->setOperand(*reg_pos_it, ficall); } } @@ -212,53 +218,50 @@ void FaultInjectionPass::createInjectionFuncforType( FunctionCallee pre_fi_func) { LLVMContext &context = M.getContext(); Function *f = M.getFunction(fi_name); - std::vector args; - for(Function::arg_iterator ai = f->arg_begin(); ai != f->arg_end(); ++ai) + std::vector args; + for (Function::arg_iterator ai = f->arg_begin(); ai != f->arg_end(); ++ai) args.push_back(&*ai); // args[0] llfi index, args[1] fault injection instruction // args[2] for opcode, args[3] for reg index, args[4] for total num of fi reg - BasicBlock* entryblock = BasicBlock::Create(context, "entry", f); + BasicBlock *entryblock = BasicBlock::Create(context, "entry", f); // store the value of target instruction to memory AllocaInst *tmploc = new AllocaInst(fitype, 0, "tmploc", entryblock); new StoreInst(args[1], tmploc, entryblock); - std::vector pre_fi_args(4); - pre_fi_args[0] = args[0]; //LLFI index - pre_fi_args[1] = args[2]; //opcode in i32 - pre_fi_args[2] = args[3]; //reg_index, not reg_pos! - pre_fi_args[3] = args[4]; //total_reg_target_num + std::vector pre_fi_args(4); + pre_fi_args[0] = args[0]; // LLFI index + pre_fi_args[1] = args[2]; // opcode in i32 + pre_fi_args[2] = args[3]; // reg_index, not reg_pos! + pre_fi_args[3] = args[4]; // total_reg_target_num // LLVM 3.3 Upgrade - ArrayRef pre_fi_args_array_ref(pre_fi_args); + ArrayRef pre_fi_args_array_ref(pre_fi_args); - Value *prefuncval = CallInst::Create(pre_fi_func, pre_fi_args_array_ref, "pre_cond", entryblock); + Value *prefuncval = CallInst::Create(pre_fi_func, pre_fi_args_array_ref, + "pre_cond", entryblock); - BasicBlock *fiblock = BasicBlock::Create(context, "inject", f); - BasicBlock *exitblock = BasicBlock::Create(context,"exit", f ); - //if prefuncval is true, goto inject function - BranchInst::Create(fiblock, exitblock, prefuncval, entryblock); + BasicBlock *fiblock = BasicBlock::Create(context, "inject", f); + BasicBlock *exitblock = BasicBlock::Create(context, "exit", f); + // if prefuncval is true, goto inject function + BranchInst::Create(fiblock, exitblock, prefuncval, entryblock); BranchInst *fi2exit_branch = BranchInst::Create(exitblock, fiblock); - std::vector fi_args(6); - fi_args[0] = args[0]; //LLFI index + std::vector fi_args(6); + fi_args[0] = args[0]; // LLFI index const DataLayout &td = M.getDataLayout(); - int size = td.getTypeSizeInBits(fitype); - fi_args[1] = ConstantInt::get(Type::getInt32Ty(context), size); //size - fi_args[2] = new BitCastInst(tmploc, - PointerType::get(Type::getInt8Ty(context), 0), - "tmploc_cast", fi2exit_branch); //pointer to target memory - fi_args[3] = args[3]; //reg_index not reg_pos! - //======== Add reg_pos QINING @DEC 23rd ========== - fi_args[4] = args[5]; // dstreg->0, operand0->1, operand1->2 ... - //================================================ - //======== Add opcode_str QINING @MAR 11th======== - fi_args[5] = args[6]; //opcode in string - //================================================ - ArrayRef fi_args_array_ref(fi_args); - + int size = (int)td.getTypeSizeInBits(fitype); + fi_args[1] = ConstantInt::get(Type::getInt32Ty(context), size); // size + fi_args[2] = new BitCastInst( + tmploc, PointerType::get(Type::getInt8Ty(context), 0), "tmploc_cast", + fi2exit_branch->getIterator()); // pointer to target memory + fi_args[3] = args[3]; // reg_index not reg_pos! + fi_args[4] = args[5]; // dstreg->0, operand0->1, operand1->2 ... + fi_args[5] = args[6]; // opcode in string + ArrayRef fi_args_array_ref(fi_args); + CallInst::Create(injectfunc, fi_args_array_ref, "", - fi2exit_branch); + fi2exit_branch->getIterator()); LoadInst *updateval = new LoadInst(fitype, tmploc, "updateval", exitblock); ReturnInst::Create(context, updateval, exitblock); @@ -268,24 +271,24 @@ void FaultInjectionPass::createInjectionFunctions(Module &M) { FunctionCallee pre_fi_func = getLLFILibPreFIFunc(M); FunctionCallee injectfunc = getLLFILibFIFunc(M); - for (std::map::const_iterator fi = - fi_rettype_funcname_map.begin(); + for (std::map::const_iterator fi = + fi_rettype_funcname_map.begin(); fi != fi_rettype_funcname_map.end(); ++fi) { const Type *fi_type = fi->first; - // LLVM 3.3 upgrading - Type *fi_type_unconst = const_cast(fi_type); + Type *fi_type_unconst = const_cast(fi_type); std::string fi_name = fi->second; - createInjectionFuncforType(M, fi_type_unconst, fi_name, injectfunc, pre_fi_func); + createInjectionFuncforType(M, fi_type_unconst, fi_name, injectfunc, + pre_fi_func); } } bool FaultInjectionPass::runOnModule(Module &M) { checkforMainFunc(M); - std::map* > *fi_inst_regs_map; + std::map *> *fi_inst_regs_map = nullptr; Controller *ctrl = Controller::getInstance(M); ctrl->getFIInstRegsMap(&fi_inst_regs_map); insertInjectionFuncCall(fi_inst_regs_map, M); @@ -295,10 +298,10 @@ bool FaultInjectionPass::runOnModule(Module &M) { } void FaultInjectionPass::checkforMainFunc(Module &M) { - Function* mainfunc = M.getFunction("main"); - if (mainfunc == NULL) { - errs() << "ERROR: Function main does not exist, " << - "which is required by LLFI\n"; + Function *mainfunc = M.getFunction("main"); + if (mainfunc == nullptr) { + errs() << "ERROR: Function main does not exist, " + << "which is required by LLFI\n"; exit(1); } } @@ -309,34 +312,33 @@ void FaultInjectionPass::finalize(Module &M) { // function call for initInjections FunctionCallee initfunc = getLLFILibInitInjectionFunc(M); - CallInst::Create(initfunc, "", entryblock->getFirstNonPHI()); - + CallInst::Create(initfunc, "", entryblock->getFirstNonPHIIt()); + // function call for postInjections FunctionCallee postfifunc = getLLFILibPostInjectionFunc(M); - std::set exitinsts; + std::set exitinsts; getProgramExitInsts(M, exitinsts); - assert (exitinsts.size() != 0 - && "Program does not have explicit exit point"); - for (std::set::iterator it = exitinsts.begin(); - it != exitinsts.end(); ++it) { + assert(!exitinsts.empty() && "Program does not have explicit exit point"); + for (std::set::iterator it = exitinsts.begin(); + it != exitinsts.end(); ++it) { Instruction *term = *it; - CallInst::Create(postfifunc, "", term); + CallInst::Create(postfifunc, "", term->getIterator()); } - + createInjectionFunctions(M); } FunctionCallee FaultInjectionPass::getLLFILibPreFIFunc(Module &M) { - std::vector pre_fi_func_param_types(4); - LLVMContext& context = M.getContext(); - pre_fi_func_param_types[0] = Type::getInt64Ty(context);// index - pre_fi_func_param_types[1] = Type::getInt32Ty(context);// opcode - pre_fi_func_param_types[2] = Type::getInt32Ty(context);// my reg index - pre_fi_func_param_types[3] = Type::getInt32Ty(context);// total reg index num + std::vector pre_fi_func_param_types(4); + LLVMContext &context = M.getContext(); + pre_fi_func_param_types[0] = Type::getInt64Ty(context); // index + pre_fi_func_param_types[1] = Type::getInt32Ty(context); // opcode + pre_fi_func_param_types[2] = Type::getInt32Ty(context); // my reg index + pre_fi_func_param_types[3] = Type::getInt32Ty(context); // total reg index num // LLVM 3.3 Upgrade - ArrayRef pre_fi_func_param_types_array_ref(pre_fi_func_param_types); + ArrayRef pre_fi_func_param_types_array_ref(pre_fi_func_param_types); FunctionType *pre_fi_func_type = FunctionType::get( Type::getInt1Ty(context), pre_fi_func_param_types_array_ref, false); @@ -346,19 +348,18 @@ FunctionCallee FaultInjectionPass::getLLFILibPreFIFunc(Module &M) { } FunctionCallee FaultInjectionPass::getLLFILibFIFunc(Module &M) { - LLVMContext& context = M.getContext(); - std::vector fi_func_param_types(6); + LLVMContext &context = M.getContext(); + std::vector fi_func_param_types(6); fi_func_param_types[0] = Type::getInt64Ty(context); // index fi_func_param_types[1] = Type::getInt32Ty(context); // size - fi_func_param_types[2] = PointerType::get(Type::getInt8Ty(context), 0); //inst + fi_func_param_types[2] = + PointerType::get(Type::getInt8Ty(context), 0); // inst fi_func_param_types[3] = Type::getInt32Ty(context); // my reg index fi_func_param_types[4] = Type::getInt32Ty(context); // reg_pos - //======== Add opcode_str QINING @MAR 11th======== fi_func_param_types[5] = PointerType::get(Type::getInt8Ty(context), 0); - //================================================ // LLVM 3.3 Upgrade - ArrayRef fi_func_param_types_array_ref(fi_func_param_types); + ArrayRef fi_func_param_types_array_ref(fi_func_param_types); FunctionType *injectfunctype = FunctionType::get( Type::getVoidTy(context), fi_func_param_types_array_ref, false); @@ -385,6 +386,6 @@ FunctionCallee FaultInjectionPass::getLLFILibPostInjectionFunc(Module &M) { return postfifunc; } -static RegisterPass X( - "faultinjectionpass", "Fault injection pass", false, true); -} +static RegisterPass X("faultinjectionpass", + "Fault injection pass", false, true); +} // namespace llfi diff --git a/llvm_passes/core/FaultInjectionPass.h b/llvm_passes/core/FaultInjectionPass.h index a6aa95a2..f7bc6904 100644 --- a/llvm_passes/core/FaultInjectionPass.h +++ b/llvm_passes/core/FaultInjectionPass.h @@ -2,10 +2,10 @@ #define FAULTINJECTION_PASS_H #include "llvm/IR/Constants.h" -#include "llvm/Pass.h" -#include "llvm/IR/Module.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/PassPlugin.h" @@ -17,53 +17,52 @@ using namespace llvm; namespace llfi { - // For Legacy PM. - class FaultInjectionPass: public ModulePass { - public: - FaultInjectionPass() : ModulePass(ID) { } - virtual bool runOnModule(Module &M); - static char ID; +// For Legacy PM. +class FaultInjectionPass : public ModulePass { +public: + FaultInjectionPass() : ModulePass(ID) {} + bool runOnModule(Module &M) override; + static char ID; - private: - void checkforMainFunc(Module &M); - void finalize(Module& M); +private: + void checkforMainFunc(Module &M); + void finalize(Module &M); - void insertInjectionFuncCall( - std::map* > *inst_regs_map, Module &M); - void createInjectionFuncforType(Module &M, Type *functype, - std::string &funcname, FunctionCallee fi_func, - FunctionCallee pre_func); - void createInjectionFunctions(Module &M); + void insertInjectionFuncCall( + std::map *> *inst_regs_map, Module &M); + void createInjectionFuncforType(Module &M, Type *functype, + std::string &funcname, FunctionCallee fi_func, + FunctionCallee pre_func); + void createInjectionFunctions(Module &M); - private: - std::string getFIFuncNameforType(const Type* type); +private: + std::string getFIFuncNameforType(const Type *type); - FunctionCallee getLLFILibPreFIFunc(Module &M); - FunctionCallee getLLFILibFIFunc(Module &M); - FunctionCallee getLLFILibInitInjectionFunc(Module &M); - FunctionCallee getLLFILibPostInjectionFunc(Module &M); + FunctionCallee getLLFILibPreFIFunc(Module &M); + FunctionCallee getLLFILibFIFunc(Module &M); + FunctionCallee getLLFILibInitInjectionFunc(Module &M); + FunctionCallee getLLFILibPostInjectionFunc(Module &M); - private: - std::map fi_rettype_funcname_map; - }; +private: + std::map fi_rettype_funcname_map; +}; - // For New PM - struct NewFaultInjectionPass: llvm::PassInfoMixin { - llvm::PreservedAnalyses run(llvm::Module &M, - llvm::ModuleAnalysisManager &){ +// For New PM +struct NewFaultInjectionPass : llvm::PassInfoMixin { + llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) { - auto obj = new FaultInjectionPass(); - bool isChanged = obj->runOnModule(M); + auto obj = new FaultInjectionPass(); + bool isChanged = obj->runOnModule(M); - delete obj; - return (isChanged) ? llvm::PreservedAnalyses::none(): - llvm::PreservedAnalyses::all(); - } + delete obj; + return (isChanged) ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all(); + } - // Without isRequired returning true, this pass will be skipped for functions - // decorated with the optnone LLVM attribute. Note that clang -O0 decorates - // all functions with optnone. - static bool isRequired() { return true; } - }; -} + // Without isRequired returning true, this pass will be skipped for functions + // decorated with the optnone LLVM attribute. Note that clang -O0 decorates + // all functions with optnone. + static bool isRequired() { return true; } +}; +} // namespace llfi #endif diff --git a/llvm_passes/core/GenLLFIIndexPass.cpp b/llvm_passes/core/GenLLFIIndexPass.cpp index ea77589c..a232b6ad 100644 --- a/llvm_passes/core/GenLLFIIndexPass.cpp +++ b/llvm_passes/core/GenLLFIIndexPass.cpp @@ -1,46 +1,48 @@ -#include "llvm/IR/Function.h" -#include +#include "GenLLFIIndexPass.h" #include "Utils.h" -#include "GenLLFIIndexPass.h" + +#include "llvm/IR/Function.h" + +#include using namespace llvm; namespace llfi { - // Main functionality of this pass - bool runOnModuleMain(Module &M) { - Instruction *currinst = NULL; - - for (Module::iterator m_it = M.begin(); m_it != M.end(); ++m_it) { - if (!m_it->isDeclaration()) { - // m_it is a function - for (inst_iterator f_it = inst_begin(&*m_it); f_it != inst_end(&*m_it); - ++f_it) { - currinst = &(*f_it); - setLLFIIndexofInst(currinst); - } +// Main functionality of this pass +bool runOnModuleMain(Module &M) { + Instruction *currinst = nullptr; + + for (Module::iterator m_it = M.begin(); m_it != M.end(); ++m_it) { + if (!m_it->isDeclaration()) { + // m_it is a function + for (inst_iterator f_it = inst_begin(&*m_it); f_it != inst_end(&*m_it); + ++f_it) { + currinst = &(*f_it); + setLLFIIndexofInst(currinst); } } + } - if (currinst) { - long totalindex = getLLFIIndexofInst(currinst); - FILE *outputFile = fopen("llfi.stat.totalindex.txt", "w"); - if (outputFile) - fprintf(outputFile, "totalindex=%ld\n", totalindex); - + if (currinst) { + long totalindex = getLLFIIndexofInst(currinst); + FILE *outputFile = fopen("llfi.stat.totalindex.txt", "w"); + if (outputFile) { + fprintf(outputFile, "totalindex=%ld\n", totalindex); fclose(outputFile); } - - return true; } - // Registration for old PM - char LegacyGenLLFIIndexPass::ID = 0; - static RegisterPass X( - "genllfiindexpass", "Generate a unique LLFI index for each instruction", + return true; +} + +// Registration for old PM +char LegacyGenLLFIIndexPass::ID = 0; +static RegisterPass + X("genllfiindexpass", "Generate a unique LLFI index for each instruction", false, false); - bool LegacyGenLLFIIndexPass::runOnModule(Module &M) { - return runOnModuleMain(M); - } +bool LegacyGenLLFIIndexPass::runOnModule(Module &M) { + return runOnModuleMain(M); } +} // namespace llfi diff --git a/llvm_passes/core/GenLLFIIndexPass.h b/llvm_passes/core/GenLLFIIndexPass.h index 6fb82f4d..76e09add 100644 --- a/llvm_passes/core/GenLLFIIndexPass.h +++ b/llvm_passes/core/GenLLFIIndexPass.h @@ -1,3 +1,6 @@ +#ifndef GEN_LLFI_INDEX_PASS_H +#define GEN_LLFI_INDEX_PASS_H + #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Module.h" @@ -6,29 +9,30 @@ using namespace llvm; -namespace llfi{ - - bool runOnModuleMain(Module&); - - // For new PM - struct GenLLFIIndexPass: llvm::PassInfoMixin { - PreservedAnalyses run(llvm::Module &M, - llvm::ModuleAnalysisManager &) { - runOnModuleMain(M); - return PreservedAnalyses::none(); - } - - // Without isRequired returning true, this pass will be skipped for functions - // decorated with the optnone LLVM attribute. Note that clang -O0 decorates - // all functions with optnone. - static bool isRequired() { return true; } - }; - - // For legacy PM - class LegacyGenLLFIIndexPass: public ModulePass { - public: - LegacyGenLLFIIndexPass() : ModulePass(ID) {} - virtual bool runOnModule(Module &M); - static char ID; - }; -} +namespace llfi { + +bool runOnModuleMain(Module &); + +// For new PM +struct GenLLFIIndexPass : llvm::PassInfoMixin { + PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) { + runOnModuleMain(M); + return PreservedAnalyses::none(); + } + + // Without isRequired returning true, this pass will be skipped for functions + // decorated with the optnone LLVM attribute. Note that clang -O0 decorates + // all functions with optnone. + static bool isRequired() { return true; } +}; + +// For legacy PM +class LegacyGenLLFIIndexPass : public ModulePass { +public: + LegacyGenLLFIIndexPass() : ModulePass(ID) {} + bool runOnModule(Module &M) override; + static char ID; +}; +} // namespace llfi + +#endif // GEN_LLFI_INDEX_PASS_H diff --git a/llvm_passes/core/InstTracePass.cpp b/llvm_passes/core/InstTracePass.cpp index 8e730fc2..dfafe158 100644 --- a/llvm_passes/core/InstTracePass.cpp +++ b/llvm_passes/core/InstTracePass.cpp @@ -10,8 +10,9 @@ Author: Sam Coulter instruction to a file specified during the pass. ***************/ -#include -#include +#include "InstTracePass.h" + +#include "Utils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -27,205 +28,199 @@ Author: Sam Coulter #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "Utils.h" -#include "InstTracePass.h" +#include +#include -cl::opt debugtrace("debugtrace", - cl::desc("Print tracing instrucmented instruction information"), - cl::init(false)); -cl::opt maxtrace( "maxtrace", - cl::desc("Maximum number of dynamic instructions that will be traced after fault injection"), - cl::init(1000)); +cl::opt + debugtrace("debugtrace", + cl::desc("Print tracing instrucmented instruction information"), + cl::init(false)); +cl::opt maxtrace("maxtrace", + cl::desc("Maximum number of dynamic instructions that " + "will be traced after fault injection"), + cl::init(1000)); using namespace llvm; namespace llfi { - bool InstTrace::doFinalization(Module &M) { - //Dont forget to delete the output filename string! - Function* mainfunc = M.getFunction("main"); - if (mainfunc == NULL) { - errs() << "ERROR: Function main does not exist, " << - "which is required by LLFI\n"; - exit(1); - } +bool InstTrace::doFinalization(Module &M) { + // Dont forget to delete the output filename string! + Function *mainfunc = M.getFunction("main"); + if (mainfunc == nullptr) { + errs() << "ERROR: Function main does not exist, " + << "which is required by LLFI\n"; + exit(1); + } - LLVMContext &context = M.getContext(); - FunctionType *postinjectfunctype = - FunctionType::get(Type::getVoidTy(context), false); - FunctionCallee postracingfunc = - M.getOrInsertFunction("postTracing", postinjectfunctype); - - std::set exitinsts; - getProgramExitInsts(M, exitinsts); - assert (exitinsts.size() != 0 - && "Program does not have explicit exit point"); - - for (std::set::iterator it = exitinsts.begin(); - it != exitinsts.end(); ++it) { - Instruction *term = *it; - CallInst::Create(postracingfunc, "", term); - } + LLVMContext &context = M.getContext(); + FunctionType *postinjectfunctype = + FunctionType::get(Type::getVoidTy(context), false); + FunctionCallee postracingfunc = + M.getOrInsertFunction("postTracing", postinjectfunctype); - return true; - } + std::set exitinsts; + getProgramExitInsts(M, exitinsts); + assert(!exitinsts.empty() && "Program does not have explicit exit point"); - long InstTrace::fetchLLFIInstructionID(Instruction *targetInst) { - return llfi::getLLFIIndexofInst(targetInst); + for (std::set::iterator it = exitinsts.begin(); + it != exitinsts.end(); ++it) { + Instruction *term = *it; + CallInst::Create(postracingfunc, "", term->getIterator()); } - Instruction* InstTrace::getInsertPoint(Instruction* llfiIndexedInst) { - Instruction *insertPoint; - if (!llfiIndexedInst->isTerminator()) { - insertPoint = llfi::getInsertPtrforRegsofInst(llfiIndexedInst, llfiIndexedInst); - // if insert point is a call to inject fault, insert printInstTrace after the injectFault call - // iff injectFault occurs AFTER the targeted instruction (ie. dst targeted) - insertPoint = changeInsertPtrIfInjectFaultInst(insertPoint); - } else { - // if terminator, insert before function - insertPoint = llfiIndexedInst; - } - return insertPoint; + return true; +} + +long InstTrace::fetchLLFIInstructionID(Instruction *targetInst) { + return llfi::getLLFIIndexofInst(targetInst); +} + +Instruction *InstTrace::getInsertPoint(Instruction *llfiIndexedInst) { + Instruction *insertPoint = nullptr; + if (!llfiIndexedInst->isTerminator()) { + insertPoint = + llfi::getInsertPtrforRegsofInst(llfiIndexedInst, llfiIndexedInst); + // if insert point is a call to inject fault, insert printInstTrace after + // the injectFault call iff injectFault occurs AFTER the targeted + // instruction (ie. dst targeted) + insertPoint = changeInsertPtrIfInjectFaultInst(insertPoint); + } else { + // if terminator, insert before function + insertPoint = llfiIndexedInst; } + return insertPoint; +} + +bool InstTrace::runOnFunction(Function &F) { + // Create handles to the functions parent module and context + LLVMContext &context = F.getContext(); + Module *M = F.getParent(); + + // iterate through each instruction of the function + inst_iterator lastInst; + for (inst_iterator instIterator = inst_begin(F), lastInst = inst_end(F); + instIterator != lastInst; ++instIterator) { + + // Print some Debug Info as the pass is being run + Instruction *inst = &*instIterator; + + if (debugtrace) { + if (!llfi::isLLFIIndexedInst(inst)) { + errs() << "Instruction " << *inst << " was not indexed\n"; + } else { + errs() << "Instruction " << *inst << " was indexed\n"; + } + } + + if (llfi::isLLFIIndexedInst(inst)) { + + // Find instrumentation point for current instruction + Instruction *insertPoint = getInsertPoint(inst); - bool InstTrace::runOnFunction(Function &F) { - //Create handles to the functions parent module and context - LLVMContext& context = F.getContext(); - Module *M = F.getParent(); - - //iterate through each instruction of the function - inst_iterator lastInst; - for (inst_iterator instIterator = inst_begin(F), - lastInst = inst_end(F); - instIterator != lastInst; ++instIterator) { - - //Print some Debug Info as the pass is being run - Instruction *inst = &*instIterator; - - if (debugtrace) { - if (!llfi::isLLFIIndexedInst(inst)) { - errs() << "Instruction " << *inst << " was not indexed\n"; - } else { - errs() << "Instruction " << *inst << " was indexed\n"; - } + // Skip instrumentation for terminating instructions + if (insertPoint->isTerminator()) { + continue; } - if (llfi::isLLFIIndexedInst(inst)) { - - //Find instrumentation point for current instruction - Instruction *insertPoint = getInsertPoint(inst); - - //Skip instrumentation for terminating instructions - if (insertPoint->isTerminator()) { - continue; - } - - //======== Find insertion location for alloca QINING @SET 15th============ - Instruction* alloca_insertPoint = inst->getParent()->getParent()->begin()->getFirstNonPHIOrDbgOrLifetime(); - //======================================================================== - - - //Fetch size of instruction value - //The size must be rounded up before conversion to bytes because some data in llvm - //can be like 1 bit if it only needs 1 bit out of an 8bit/1byte data type - float bitSize; - AllocaInst* ptrInst; - if (inst->getType() != Type::getVoidTy(context)) { - //insert an instruction Allocate stack memory to store/pass instruction value - ptrInst = new AllocaInst(inst->getType(), 0, "llfi_trace", - alloca_insertPoint); - //Insert an instruction to Store the instruction Value! - new StoreInst(inst, ptrInst, insertPoint); - - const DataLayout &td = F.getParent()->getDataLayout(); - bitSize = (float)td.getTypeSizeInBits(inst->getType()); - } - else { - ptrInst = new AllocaInst(Type::getInt32Ty(context), 0, "llfi_trace", - alloca_insertPoint); - new StoreInst(ConstantInt::get(IntegerType::get(context, 32), 0), - ptrInst, insertPoint); - bitSize = 32; - } - int byteSize = (int)ceil(bitSize / 8.0); - - //Insert instructions to allocate stack memory for opcode name - const char* opcodeNamePt = inst->getOpcodeName(); - const std::string str(inst->getOpcodeName()); - ArrayRef opcode_name_array_ref((uint8_t*)opcodeNamePt, str.size() + 1); - //llvm::Value* OPCodeName = llvm::ConstantArray::get(context, opcode_name_array_ref); - llvm::Value* OPCodeName = llvm::ConstantDataArray::get(context, opcode_name_array_ref); - /********************************/ - - AllocaInst *OPCodePtr = new AllocaInst( - OPCodeName->getType(), 0, "llfi_trace", alloca_insertPoint); - new StoreInst(OPCodeName, OPCodePtr, insertPoint); - - //Create the decleration of the printInstTracer Function - std::vector parameterVector(5); - parameterVector[0] = Type::getInt32Ty(context); //ID - parameterVector[1] = OPCodePtr->getType(); - //======== opcode_str QINING @SET 15th============ - //parameterVector[1] = PointerType::get(Type::getInt8Ty(context), 0); //Ptr to OpCode - //================================================ - parameterVector[2] = Type::getInt32Ty(context); //Size of Inst Value - parameterVector[3] = ptrInst->getType(); //Ptr to Inst Value - parameterVector[4] = Type::getInt32Ty(context); //Int of max traces - - //LLVM 3.3 Upgrade - ArrayRef parameterVector_array_ref(parameterVector); - - FunctionType* traceFuncType = FunctionType::get(Type::getVoidTy(context), - parameterVector_array_ref, false); - FunctionCallee traceFunc = - M->getOrInsertFunction("printInstTracer", traceFuncType); - - //Insert the tracing function, passing it the proper arguments - std::vector traceArgs; - //Fetch the LLFI Instruction ID: - ConstantInt* IDConstInt = ConstantInt::get(IntegerType::get(context, 32), - fetchLLFIInstructionID(inst)); - - ConstantInt* instValSize = ConstantInt::get( - IntegerType::get(context, 32), byteSize); - - //Fetch maxtrace number: - ConstantInt* maxTraceConstInt = - ConstantInt::get(IntegerType::get(context, 32), maxtrace); - - //======== opcode_str QINING @SET 15th============ - //string opcode_str = fi_inst->getOpcodeName(); - //GlobalVariable* opcode_str_gv = findOrCreateGlobalNameString(M, opcode_str); - //vector indices_for_gep(2); - //indices_for_gep[0] = ConstantInt::get(Type::getInt32Ty(context),0); - //indices_for_gep[1] = ConstantInt::get(Type::getInt32Ty(context),0); - //ArrayRef gep_expr_ref(indices_for_gep); - //Constant* gep_expr_opcode = ConstantExpr::getGetElementPtr(opcode_str_gv, gep_expr_ref); - //================================================ - - //Load All Arguments - traceArgs.push_back(IDConstInt); - traceArgs.push_back(OPCodePtr); - traceArgs.push_back(instValSize); - traceArgs.push_back(ptrInst); - traceArgs.push_back(maxTraceConstInt); - - //LLVM 3.3 Upgrade - ArrayRef traceArgs_array_ref(traceArgs); - - //Create the Function - CallInst::Create(traceFunc, traceArgs_array_ref, "", insertPoint); + Instruction *alloca_insertPoint = inst->getParent() + ->getParent() + ->begin() + ->getFirstNonPHIOrDbgOrLifetime(); + + // Fetch size of instruction value + // The size must be rounded up before conversion to bytes because some + // data in llvm can be like 1 bit if it only needs 1 bit out of an + // 8bit/1byte data type + float bitSize = 0.0f; + AllocaInst *ptrInst = nullptr; + if (inst->getType() != Type::getVoidTy(context)) { + // insert an instruction Allocate stack memory to store/pass instruction + // value + ptrInst = new AllocaInst(inst->getType(), 0, "llfi_trace", + alloca_insertPoint); + // Insert an instruction to Store the instruction Value! + new StoreInst(inst, ptrInst, insertPoint->getIterator()); + + const DataLayout &td = F.getParent()->getDataLayout(); + bitSize = (float)td.getTypeSizeInBits(inst->getType()); + } else { + ptrInst = new AllocaInst(Type::getInt32Ty(context), 0, "llfi_trace", + alloca_insertPoint); + new StoreInst(ConstantInt::get(IntegerType::get(context, 32), 0), + ptrInst, insertPoint->getIterator()); + bitSize = 32; } - }//Function Iteration + int byteSize = (int)ceil(bitSize / 8.0); + + // Insert instructions to allocate stack memory for opcode name + const char *opcodeNamePt = inst->getOpcodeName(); + const std::string str(inst->getOpcodeName()); + ArrayRef opcode_name_array_ref( + reinterpret_cast(opcodeNamePt), str.size() + 1); + // llvm::Value* OPCodeName = llvm::ConstantArray::get(context, + // opcode_name_array_ref); + llvm::Value *OPCodeName = + llvm::ConstantDataArray::get(context, opcode_name_array_ref); + /********************************/ + + AllocaInst *OPCodePtr = new AllocaInst(OPCodeName->getType(), 0, + "llfi_trace", alloca_insertPoint); + new StoreInst(OPCodeName, OPCodePtr, insertPoint->getIterator()); + + // Create the decleration of the printInstTracer Function + std::vector parameterVector(5); + parameterVector[0] = Type::getInt32Ty(context); // ID + parameterVector[1] = OPCodePtr->getType(); + parameterVector[2] = Type::getInt32Ty(context); // Size of Inst Value + parameterVector[3] = ptrInst->getType(); // Ptr to Inst Value + parameterVector[4] = Type::getInt32Ty(context); // Int of max traces + + // LLVM 3.3 Upgrade + ArrayRef parameterVector_array_ref(parameterVector); + + FunctionType *traceFuncType = FunctionType::get( + Type::getVoidTy(context), parameterVector_array_ref, false); + FunctionCallee traceFunc = + M->getOrInsertFunction("printInstTracer", traceFuncType); + + // Insert the tracing function, passing it the proper arguments + std::vector traceArgs; + // Fetch the LLFI Instruction ID: + ConstantInt *IDConstInt = ConstantInt::get(IntegerType::get(context, 32), + fetchLLFIInstructionID(inst)); + + ConstantInt *instValSize = + ConstantInt::get(IntegerType::get(context, 32), byteSize); + + // Fetch maxtrace number: + ConstantInt *maxTraceConstInt = + ConstantInt::get(IntegerType::get(context, 32), maxtrace); + + // Load All Arguments + traceArgs.push_back(IDConstInt); + traceArgs.push_back(OPCodePtr); + traceArgs.push_back(instValSize); + traceArgs.push_back(ptrInst); + traceArgs.push_back(maxTraceConstInt); + + // LLVM 3.3 Upgrade + ArrayRef traceArgs_array_ref(traceArgs); + + // Create the Function + CallInst::Create(traceFunc, traceArgs_array_ref, "", + insertPoint->getIterator()); + } + } // Function Iteration - return true; //Tell LLVM that the Function was modified - }//RunOnFunction + return true; // Tell LLVM that the Function was modified +} // RunOnFunction -//Register the pass with the llvm +// Register the pass with the llvm char InstTrace::ID = 0; static RegisterPass X("insttracepass", - "Add tracing function calls in program to trace instruction value at runtime", - false, false); - -}//namespace llfi + "Add tracing function calls in program to " + "trace instruction value at runtime", + false, false); +} // namespace llfi diff --git a/llvm_passes/core/InstTracePass.h b/llvm_passes/core/InstTracePass.h index 6cd2bdf0..78e6d65a 100644 --- a/llvm_passes/core/InstTracePass.h +++ b/llvm_passes/core/InstTracePass.h @@ -1,3 +1,6 @@ +#ifndef INST_TRACE_PASS_H +#define INST_TRACE_PASS_H + #include "llvm/IR/PassManager.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/PassPlugin.h" @@ -6,45 +9,43 @@ using namespace llvm; namespace llfi { - struct InstTrace : public FunctionPass { +struct InstTrace : public FunctionPass { - static char ID; + static char ID; - InstTrace() : FunctionPass(ID) {} + InstTrace() : FunctionPass(ID) {} - virtual bool doInitialization(Module &M) { - return false; - } + bool doInitialization(Module &M) override { return false; } - virtual bool doFinalization(Module &M); + bool doFinalization(Module &M) override; - long fetchLLFIInstructionID(Instruction *targetInst); + long fetchLLFIInstructionID(Instruction *targetInst); - Instruction* getInsertPoint(Instruction* llfiIndexedInst); + Instruction *getInsertPoint(Instruction *llfiIndexedInst); - virtual bool runOnFunction(Function &F); - }; + bool runOnFunction(Function &F) override; +}; - struct NewInstTrace: llvm::PassInfoMixin { - llvm::PreservedAnalyses run(llvm::Module &M, - llvm::ModuleAnalysisManager &){ +struct NewInstTrace : llvm::PassInfoMixin { + llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) { - InstTrace tempObj; - tempObj.doInitialization(M); + InstTrace tempObj; + tempObj.doInitialization(M); - for (Function &F : M) { - tempObj.runOnFunction(F); - } + for (Function &F : M) { + tempObj.runOnFunction(F); + } - tempObj.doFinalization(M); + tempObj.doFinalization(M); - return PreservedAnalyses::none(); - } + return PreservedAnalyses::none(); + } - // Without isRequired returning true, this pass will be skipped for functions - // decorated with the optnone LLVM attribute. Note that clang -O0 decorates - // all functions with optnone. - static bool isRequired() { return true; } - }; -}//namespace llfi + // Without isRequired returning true, this pass will be skipped for functions + // decorated with the optnone LLVM attribute. Note that clang -O0 decorates + // all functions with optnone. + static bool isRequired() { return true; } +}; +} // namespace llfi +#endif // INST_TRACE_PASS_H diff --git a/llvm_passes/core/LLFIDotGraphPass.cpp b/llvm_passes/core/LLFIDotGraphPass.cpp index d2290420..8b422b1f 100644 --- a/llvm_passes/core/LLFIDotGraphPass.cpp +++ b/llvm_passes/core/LLFIDotGraphPass.cpp @@ -1,25 +1,26 @@ -#include -#include -#include +#include "Utils.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/GlobalValue.h" -#include "llvm/Pass.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/Support/Debug.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Value.h" -#include "Utils.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/PassPlugin.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" + +#include +#include +#include #define DATADEPCOLOUR "blue" @@ -47,17 +48,21 @@ instNode::instNode(Instruction *target) { label += std::string("\\n") + target->getOpcodeName() + "\\n"; DebugLoc dbgLoc = target->getDebugLoc(); if (bool(dbgLoc) && dbgLoc.getLine()) { - label += "(Line #: " + intToString(dbgLoc.getLine()) + ")\\n"; + label += "(Line #: " + intToString((int)dbgLoc.getLine()) + ")\\n"; /* if (MDNode *N= target->getMetadata("dbg")){ - label += "(In File: " + DILocation (N).getFilename().str().substr(DILocation (N).getFilename().str().find_last_of("/\\")+1)+")"; + label += "(In File: " + DILocation + (N).getFilename().str().substr(DILocation + (N).getFilename().str().find_last_of("/\\")+1)+")"; } */ if (outputFile) - fprintf(outputFile, "%s line_%s\n", name.c_str(),intToString(target->getDebugLoc().getLine()).c_str()); - } - else{ + fprintf(outputFile, "%s line_%s\n", name.c_str(), + intToString((int)target->getDebugLoc().getLine()).c_str()); + } else { if (outputFile) fprintf(outputFile, "%s line_N/A\n", name.c_str()); } + if (outputFile) + fclose(outputFile); label += "\"]"; } @@ -66,14 +71,14 @@ std::string instNode::dotNode() { } struct bBlockGraph { - BasicBlock* raw; + BasicBlock *raw; std::string name; std::string funcName; std::vector instNodes; - Instruction* entryInst; - Instruction* exitInst; + Instruction *entryInst; + Instruction *exitInst; bBlockGraph(BasicBlock *target); - bool addInstruction(Instruction* inst); + bool addInstruction(Instruction *inst); bool writeToStream(std::ofstream &target); }; @@ -82,10 +87,8 @@ bBlockGraph::bBlockGraph(BasicBlock *BB) { name = BB->getName().str(); funcName = BB->getParent()->getName().str(); BasicBlock::iterator lastInst; - for (BasicBlock::iterator instIterator = BB->begin(), - lastInst = BB->end(); - instIterator != lastInst; - ++instIterator) { + for (BasicBlock::iterator instIterator = BB->begin(), lastInst = BB->end(); + instIterator != lastInst; ++instIterator) { Instruction *inst = &*instIterator; @@ -94,7 +97,7 @@ bBlockGraph::bBlockGraph(BasicBlock *BB) { entryInst = &(BB->front()); exitInst = &(BB->back()); } -bool bBlockGraph::addInstruction(Instruction* inst) { +bool bBlockGraph::addInstruction(Instruction *inst) { instNodes.push_back(instNode(inst)); return true; @@ -108,44 +111,45 @@ bool bBlockGraph::writeToStream(std::ofstream &target) { } target << "}\n"; for (unsigned int i = 1; i < instNodes.size(); i++) { - target << instNodes.at(i-1).name << " -> " << instNodes.at(i).name << ";\n"; + target << instNodes.at(i - 1).name << " -> " << instNodes.at(i).name + << ";\n"; } return true; } - bool llfiDotGraph::runOnFunction(Function &F) { - //Create handles to the functions parent module and context + // Create handles to the functions parent module and context LLVMContext &context = F.getContext(); std::vector blocks; Function::iterator lastBlock; - //iterate through each basicblock of the function + // iterate through each basicblock of the function for (Function::iterator blockIterator = F.begin(), lastBlock = F.end(); - blockIterator != lastBlock; ++blockIterator) { + blockIterator != lastBlock; ++blockIterator) { - BasicBlock* block = &*blockIterator; + BasicBlock *block = &*blockIterator; bBlockGraph b(block); blocks.push_back(b); } for (unsigned int i = 0; i < blocks.size(); i++) { - bBlockGraph currBlock = blocks.at(i); + const bBlockGraph &currBlock = blocks.at(i); for (unsigned int i = 0; i < currBlock.instNodes.size(); i++) { Instruction *inst = currBlock.instNodes.at(i).raw; std::string nodeName = currBlock.instNodes.at(i).name; instNode node = currBlock.instNodes.at(i); if (!inst->use_empty()) { // TODO: optimize the algorithm below later - // Iterates over the uses of instruction and finds their basic blocks and annotates them + // Iterates over the uses of instruction and finds their basic blocks + // and annotates them for (Value::use_iterator useIter = inst->use_begin(); useIter != inst->use_end(); useIter++) { - Value* userValue = *useIter; + Value *userValue = *useIter; for (unsigned int f = 0; f < blocks.size(); f++) { - bBlockGraph searchBlock = blocks.at(f); + const bBlockGraph &searchBlock = blocks.at(f); for (unsigned int d = 0; d < searchBlock.instNodes.size(); d++) { - Instruction* targetInst = searchBlock.instNodes.at(d).raw; + Instruction *targetInst = searchBlock.instNodes.at(d).raw; if (userValue == targetInst) { instNode targetNode = searchBlock.instNodes.at(d); outfs << nodeName << " -> " << targetNode.name; @@ -162,14 +166,14 @@ bool llfiDotGraph::runOnFunction(Function &F) { bBlockGraph block = blocks.at(i); block.writeToStream(outfs); if (block.exitInst->getOpcode() == Instruction::Br) { - BranchInst* exitInst = (BranchInst*)block.exitInst; + BranchInst *exitInst = dyn_cast(block.exitInst); for (unsigned int s = 0; s < exitInst->getNumSuccessors(); s++) { - BasicBlock* succ = exitInst->getSuccessor(s); + BasicBlock *succ = exitInst->getSuccessor(s); for (unsigned int d = 0; d < blocks.size(); d++) { if (blocks.at(d).raw == succ) { std::string from = block.instNodes.back().name; std::string to = blocks.at(d).instNodes.front().name; - outfs << from << " -> " << to << ";\n"; + outfs << from << " -> " << to << ";\n"; } } } @@ -179,9 +183,10 @@ bool llfiDotGraph::runOnFunction(Function &F) { return false; } -//Register the pass with the llvm +// Register the pass with the llvm char llfiDotGraph::ID = 0; -static RegisterPass X("dotgraphpass", - "Outputs a dot graph of instruction execution at runtime", false, false); +static RegisterPass + X("dotgraphpass", "Outputs a dot graph of instruction execution at runtime", + false, false); -} +} // namespace llfi diff --git a/llvm_passes/core/LLFIDotGraphPass.h b/llvm_passes/core/LLFIDotGraphPass.h index f3e9d12f..552d3a7b 100644 --- a/llvm_passes/core/LLFIDotGraphPass.h +++ b/llvm_passes/core/LLFIDotGraphPass.h @@ -1,78 +1,85 @@ -#include -#include +#ifndef LLFI_DOT_GRAPH_PASS_H +#define LLFI_DOT_GRAPH_PASS_H + #include "llvm/Support/raw_os_ostream.h" -namespace llfi{ +#include +#include - struct llfiDotGraph : public FunctionPass { - static char ID; - std::ofstream outfs; - llfiDotGraph() : FunctionPass(ID) {} +namespace llfi { - virtual bool doInitialization(Module &M) { - outfs.open("llfi.stat.graph.dot", std::ios::trunc); - outfs << "digraph \"LLFI Program Graph\" {\n"; +struct llfiDotGraph : public FunctionPass { + static char ID; + std::ofstream outfs; + llfiDotGraph() : FunctionPass(ID) {} - return false; - } + bool doInitialization(Module &M) override { + outfs.open("llfi.stat.graph.dot", std::ios::trunc); + outfs << "digraph \"LLFI Program Graph\" {\n"; - virtual bool doFinalization(Module &M) { - outfs << "{ rank = sink;" - "Legend [shape=none, margin=0, label=<" - "" - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - " " - "
Legend
Correct Control Flow solid arrow
Data Dependancy solid arrow
Error Propogation Flowsolid arrow
The Affected Instruction(s) by Fault Injection
The Instruction(s) LLFI Injects Faults to
" - ">];" - "}"; - outfs << "}\n"; - outfs.close(); - return false; - } + return false; + } - virtual bool runOnFunction(Function &F); - }; + bool doFinalization(Module &M) override { + outfs << "{ rank = sink;" + "Legend [shape=none, margin=0, label=<" + "" + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + " " + "
Legend
Correct Control Flow solid arrow
Data Dependancy solid arrow
Error Propogation Flowsolid arrow
The Affected Instruction(s) by Fault Injection
The Instruction(s) LLFI Injects Faults to
" + ">];" + "}"; + outfs << "}\n"; + outfs.close(); + return false; + } - struct NewLLFIDotGraph : llvm::PassInfoMixin { + bool runOnFunction(Function &F) override; +}; - // Without isRequired returning true, this pass will be skipped for functions - // decorated with the optnone LLVM attribute. Note that clang -O0 decorates - // all functions with optnone. - static bool isRequired() { return true; } +struct NewLLFIDotGraph : llvm::PassInfoMixin { - // Main entry point, takes IR unit to run the pass on (&F) and the - // corresponding pass manager (to be queried if need be) - PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { - llfiDotGraph tempObj; - tempObj.doInitialization(M); + // Without isRequired returning true, this pass will be skipped for functions + // decorated with the optnone LLVM attribute. Note that clang -O0 decorates + // all functions with optnone. + static bool isRequired() { return true; } - for (Function &F : M) { - tempObj.runOnFunction(F); - } + // Main entry point, takes IR unit to run the pass on (&F) and the + // corresponding pass manager (to be queried if need be) + PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { + llfiDotGraph tempObj; + tempObj.doInitialization(M); - tempObj.doFinalization(M); - return PreservedAnalyses::all(); + for (Function &F : M) { + tempObj.runOnFunction(F); } - }; -} + + tempObj.doFinalization(M); + return PreservedAnalyses::all(); + } +}; +} // namespace llfi + +#endif // LLFI_DOT_GRAPH_PASS_H diff --git a/llvm_passes/core/ProfilingPass.cpp b/llvm_passes/core/ProfilingPass.cpp index 7bb97d96..4cc49efc 100644 --- a/llvm_passes/core/ProfilingPass.cpp +++ b/llvm_passes/core/ProfilingPass.cpp @@ -12,38 +12,39 @@ // definition is linked to the instrumented bitcode file (after this pass). //===----------------------------------------------------------------------===// +#include "ProfilingPass.h" + +#include "Controller.h" +#include "Utils.h" + #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/Support/raw_ostream.h" -//BEHROOZ: #include "llvm/Support/CommandLine.h" - +#include "llvm/Support/raw_ostream.h" #include #include #include -#include "ProfilingPass.h" -#include "Controller.h" -#include "Utils.h" - using namespace llvm; namespace llfi { -char LegacyProfilingPass::ID=0; -extern cl::opt< std::string > llfilogfile; +char LegacyProfilingPass::ID = 0; +extern cl::opt llfilogfile; // Flag to enable/disable output of FI statistics for ML applications in the // llfi.stat.fi.injectedfaults.txt file. // Enabling this option will cause LLTFI to output the layer type and number in // which the fault is injected. This option is disabled by default. -static cl::opt< bool > mlfistats("mlfistats", - cl::desc("Flag to disable or enable the FI statistics of ML applications. \ - Default value: false."), cl::init(false)); +static cl::opt mlfistats( + "mlfistats", + cl::desc("Flag to disable or enable the FI statistics of ML applications. \ + Default value: false."), + cl::init(false)); // Find all the call to OMInstrumentPoint function and insert a call to // lltfiMLLayer function before each call to OMInstrumentPoint. @@ -53,33 +54,37 @@ void insertCallForMLFIStats(Module &M) { // Find main_graph function in this module. Function *main_graph = M.getFunction("main_graph"); + if (!main_graph) + return; // Iterate over all instructions of the main_graph function. - for (Function::iterator bb = main_graph->begin(); bb != main_graph->end(); bb++) { + for (Function::iterator bb = main_graph->begin(); bb != main_graph->end(); + bb++) { for (BasicBlock::iterator inst = bb->begin(); inst != bb->end(); inst++) { // If the instruction is a call instruction, check if it is a call to // the OMInstrumentPoint function. if (isa(inst)) { - CallInst *call_inst = dyn_cast(inst); - if (call_inst->getCalledFunction()->getName() == "OMInstrumentPoint") { + CallInst *call_inst = cast(inst); + if (call_inst->getCalledFunction() && + call_inst->getCalledFunction()->getName() == "OMInstrumentPoint") { // Clone the instruction and reassign the operands. - Instruction* duplicatedInst = inst->clone(); - for(unsigned int i = 0; i < duplicatedInst->getNumOperands(); i++){ - duplicatedInst->setOperand(i, inst->getOperand(i)); + Instruction *duplicatedInst = inst->clone(); + for (unsigned int i = 0; i < duplicatedInst->getNumOperands(); i++) { + duplicatedInst->setOperand(i, inst->getOperand(i)); } auto Fn = inst->getFunction()->getParent()->getOrInsertFunction( - "lltfiMLLayer", Type::getVoidTy(inst->getContext()), - Type::getInt64Ty(inst->getContext()), - Type::getInt64Ty(inst->getContext())); + "lltfiMLLayer", Type::getVoidTy(inst->getContext()), + Type::getInt64Ty(inst->getContext()), + Type::getInt64Ty(inst->getContext())); // Change name of the duplicate call instruction. - CallInst *duplicateCall = dyn_cast(duplicatedInst); + CallInst *duplicateCall = cast(duplicatedInst); duplicateCall->setCalledFunction(Fn); // Insert the duplicate instruction - duplicatedInst->insertBefore(inst->getNextNode()); + duplicatedInst->insertBefore(inst->getNextNode()->getIterator()); } } } @@ -87,62 +92,70 @@ void insertCallForMLFIStats(Module &M) { } bool LegacyProfilingPass::runOnModule(Module &M) { - LLVMContext &context = M.getContext(); + LLVMContext &context = M.getContext(); - std::map* > *fi_inst_regs_map; + std::map *> *fi_inst_regs_map = nullptr; Controller *ctrl = Controller::getInstance(M); ctrl->getFIInstRegsMap(&fi_inst_regs_map); - //BEHROOZ: std::error_code err; raw_fd_ostream logFile(llfilogfile.c_str(), err, sys::fs::OF_Append); - for (std::map* >::const_iterator - inst_reg_it = fi_inst_regs_map->begin(); + for (std::map *>::const_iterator inst_reg_it = + fi_inst_regs_map->begin(); inst_reg_it != fi_inst_regs_map->end(); ++inst_reg_it) { Instruction *fi_inst = inst_reg_it->first; - std::list *fi_regs = inst_reg_it->second; + std::list *fi_regs = inst_reg_it->second; - /*BEHROOZ: This section makes sure that we do not instrument the intrinsic functions*/ - if(isa(fi_inst)){ - bool continue_flag=false; + // Skip intrinsic functions to avoid invalid instrumentation + if (isa(fi_inst)) { + bool continue_flag = false; for (std::list::iterator reg_pos_it_mem = fi_regs->begin(); - (reg_pos_it_mem != fi_regs->end()) && (*reg_pos_it_mem != DST_REG_POS); ++reg_pos_it_mem) { + (reg_pos_it_mem != fi_regs->end()) && + (*reg_pos_it_mem != DST_REG_POS); + ++reg_pos_it_mem) { std::string reg_mem = fi_inst->getOperand(*reg_pos_it_mem)->getName().str(); - if ((reg_mem.find("memcpy") != std::string::npos) || (reg_mem.find("memset") != std::string::npos) || (reg_mem.find("expect") != std::string::npos) || (reg_mem.find("memmove") != std::string::npos)){ - logFile << "LLFI cannot instrument " << reg_mem << " intrinsic function"<< "\n"; - continue_flag=true; + if ((reg_mem.find("memcpy") != std::string::npos) || + (reg_mem.find("memset") != std::string::npos) || + (reg_mem.find("expect") != std::string::npos) || + (reg_mem.find("memmove") != std::string::npos)) { + logFile << "LLFI cannot instrument " << reg_mem + << " intrinsic function" << "\n"; + continue_flag = true; break; } } - if(continue_flag) + if (continue_flag) continue; } - /*BEHROOZ: This is to make sure we do not instrument landingpad instructions.*/ + // Skip landingpad instructions which cannot be instrumented std::string current_opcode = fi_inst->getOpcodeName(); - if(current_opcode.find("landingpad") != std::string::npos){ - logFile << "LLFI cannot instrument " << current_opcode << " instruction" << "\n"; + if (current_opcode.find("landingpad") != std::string::npos) { + logFile << "LLFI cannot instrument " << current_opcode << " instruction" + << "\n"; continue; } - Value *fi_reg = *(fi_regs->begin())==DST_REG_POS ? fi_inst : (fi_inst->getOperand(*(fi_regs->begin()))); + Value *fi_reg = *(fi_regs->begin()) == DST_REG_POS + ? fi_inst + : (fi_inst->getOperand(*(fi_regs->begin()))); Instruction *insertptr = getInsertPtrforRegsofInst(fi_reg, fi_inst); // function declaration FunctionCallee profilingfunc = getLLFILibProfilingFunc(M); // prepare for the calling argument and call the profiling function - std::vector profilingarg(1); - const IntegerType* itype = IntegerType::get(context, 32); + std::vector profilingarg(1); + const IntegerType *itype = IntegerType::get(context, 32); - //LLVM 3.3 Upgrading - IntegerType* itype_non_const = const_cast(itype); - Value* opcode = ConstantInt::get(itype_non_const, fi_inst->getOpcode()); + // LLVM 3.3 Upgrading + IntegerType *itype_non_const = const_cast(itype); + Value *opcode = ConstantInt::get(itype_non_const, fi_inst->getOpcode()); profilingarg[0] = opcode; - ArrayRef profilingarg_array_ref(profilingarg); + ArrayRef profilingarg_array_ref(profilingarg); - CallInst::Create(profilingfunc, profilingarg_array_ref, - "", insertptr); + CallInst::Create(profilingfunc, profilingarg_array_ref, "", + insertptr->getIterator()); } logFile.close(); @@ -155,49 +168,48 @@ bool LegacyProfilingPass::runOnModule(Module &M) { } void LegacyProfilingPass::addEndProfilingFuncCall(Module &M) { - Function* mainfunc = M.getFunction("main"); - if (mainfunc != NULL) { + Function *mainfunc = M.getFunction("main"); + if (mainfunc != nullptr) { FunctionCallee endprofilefunc = getLLFILibEndProfilingFunc(M); // function call - std::set exitinsts; + std::set exitinsts; getProgramExitInsts(M, exitinsts); - assert (exitinsts.size() != 0 - && "Program does not have explicit exit point"); + assert(!exitinsts.empty() && "Program does not have explicit exit point"); - for (std::set::iterator it = exitinsts.begin(); + for (std::set::iterator it = exitinsts.begin(); it != exitinsts.end(); ++it) { Instruction *term = *it; - CallInst::Create(endprofilefunc, "", term); + CallInst::Create(endprofilefunc, "", term->getIterator()); } } } FunctionCallee LegacyProfilingPass::getLLFILibProfilingFunc(Module &M) { LLVMContext &context = M.getContext(); - std::vector paramtypes(1); + std::vector paramtypes(1); paramtypes[0] = Type::getInt32Ty(context); // LLVM 3.3 Upgrading - ArrayRef paramtypes_array_ref(paramtypes); + ArrayRef paramtypes_array_ref(paramtypes); - FunctionType* profilingfunctype = FunctionType::get( - Type::getVoidTy(context), paramtypes_array_ref, false); + FunctionType *profilingfunctype = + FunctionType::get(Type::getVoidTy(context), paramtypes_array_ref, false); FunctionCallee profilingfunc = M.getOrInsertFunction("doProfiling", profilingfunctype); return profilingfunc; } FunctionCallee LegacyProfilingPass::getLLFILibEndProfilingFunc(Module &M) { - LLVMContext& context = M.getContext(); - FunctionType* endprofilingfunctype = FunctionType::get( - Type::getVoidTy(context), false); + LLVMContext &context = M.getContext(); + FunctionType *endprofilingfunctype = + FunctionType::get(Type::getVoidTy(context), false); FunctionCallee endprofilefunc = M.getOrInsertFunction("endProfiling", endprofilingfunctype); return endprofilefunc; } // Registration for the old PM -static RegisterPass X("profilingpass", - "Profiling pass", false, false); -} +static RegisterPass X("profilingpass", "Profiling pass", + false, false); +} // namespace llfi diff --git a/llvm_passes/core/ProfilingPass.h b/llvm_passes/core/ProfilingPass.h index aacc91bf..928c269e 100644 --- a/llvm_passes/core/ProfilingPass.h +++ b/llvm_passes/core/ProfilingPass.h @@ -1,12 +1,12 @@ -//This pass is run after the transform pass for inserting hooks -//for fault injection +// This pass is run after the transform pass for inserting hooks +// for fault injection #ifndef PROFILING_PASS_H #define PROFILING_PASS_H #include "llvm/IR/Constants.h" -#include "llvm/Pass.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/PassPlugin.h" @@ -16,37 +16,35 @@ using namespace llvm; namespace llfi { - // For legacy PM - class LegacyProfilingPass: public ModulePass { - public: - LegacyProfilingPass() : ModulePass(ID) {} - virtual bool runOnModule(Module &M); - static char ID; - - private: - void addEndProfilingFuncCall(Module &M); - private: - FunctionCallee getLLFILibProfilingFunc(Module &M); - FunctionCallee getLLFILibEndProfilingFunc(Module &M); - }; - - // For new PM - struct ProfilingPass: llvm::PassInfoMixin { - llvm::PreservedAnalyses run(llvm::Module &M, - llvm::ModuleAnalysisManager &){ - - auto obj = new LegacyProfilingPass(); - bool isChanged = obj->runOnModule(M); - - delete obj; - return (isChanged) ? llvm::PreservedAnalyses::none(): - llvm::PreservedAnalyses::all(); - } - - // Without isRequired returning true, this pass will be skipped for functions - // decorated with the optnone LLVM attribute. Note that clang -O0 decorates - // all functions with optnone. - static bool isRequired() { return true; } - }; -} +// For legacy PM +class LegacyProfilingPass : public ModulePass { +public: + LegacyProfilingPass() : ModulePass(ID) {} + bool runOnModule(Module &M) override; + static char ID; + +private: + void addEndProfilingFuncCall(Module &M); + +private: + FunctionCallee getLLFILibProfilingFunc(Module &M); + FunctionCallee getLLFILibEndProfilingFunc(Module &M); +}; + +// For new PM +struct ProfilingPass : llvm::PassInfoMixin { + llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) { + + LegacyProfilingPass obj; + bool isChanged = obj.runOnModule(M); + return (isChanged) ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all(); + } + + // Without isRequired returning true, this pass will be skipped for functions + // decorated with the optnone LLVM attribute. Note that clang -O0 decorates + // all functions with optnone. + static bool isRequired() { return true; } +}; +} // namespace llfi #endif diff --git a/llvm_passes/core/RegLocBasedFIRegSelector.cpp b/llvm_passes/core/RegLocBasedFIRegSelector.cpp index 973e9467..c8b32b4a 100644 --- a/llvm_passes/core/RegLocBasedFIRegSelector.cpp +++ b/llvm_passes/core/RegLocBasedFIRegSelector.cpp @@ -2,26 +2,30 @@ namespace llfi { -bool RegLocBasedFIRegSelector::isRegofInstFITarget(Value *reg, - Instruction *inst) { +bool RegLocBasedFIRegSelector::isRegofInstFITarget(Value *reg, + Instruction *inst) { if (firegloc == dstreg) { return reg == inst; } else if (firegloc == allsrcreg) { - if(isa(inst)){ - if(inst->getOperand(inst->getNumOperands()-1) == reg && isa(reg)) return false; + if (isa(inst)) { + if (inst->getOperand(inst->getNumOperands() - 1) == reg && + isa(reg)) + return false; } // Switch case values must remain constants (LLVM IR constraint); skip them - if(isa(inst) && isa(reg)) return false; + if (isa(inst) && isa(reg)) + return false; return reg != inst; } else if (firegloc == allreg) { - // dbgs() << "Choosing all regs" << "\n"; - return true; + // dbgs() << "Choosing all regs" << "\n"; + return true; } else { - unsigned srcindex = (unsigned) (firegloc - srcreg1); + unsigned srcindex = (unsigned)(firegloc - srcreg1); unsigned totalsrcregnum = inst->getNumOperands(); if (srcindex < totalsrcregnum) { - if(isa(inst)){ - if(inst->getOperand(totalsrcregnum-1) == reg && isa(reg)) return false; + if (isa(inst)) { + if (inst->getOperand(totalsrcregnum - 1) == reg && isa(reg)) + return false; } return inst->getOperand(srcindex) == reg; } else @@ -29,14 +33,13 @@ bool RegLocBasedFIRegSelector::isRegofInstFITarget(Value *reg, } } -bool RegLocBasedFIRegSelector::isRegofInstFITarget(Value *reg, - Instruction *inst, - int pos) { - bool result = isRegofInstFITarget(reg, inst); +bool RegLocBasedFIRegSelector::isRegofInstFITarget(Value *reg, + Instruction *inst, int pos) { + bool result = isRegofInstFITarget(reg, inst); // Only check position if it's not allsrcreg, dstreg or all reg - if (! (firegloc == allsrcreg || firegloc == dstreg || firegloc == allreg) ) - result = result && (firegloc - srcreg1) == pos; + if (!(firegloc == allsrcreg || firegloc == dstreg || firegloc == allreg)) + result = result && (firegloc - srcreg1) == pos; return result; } -} +} // namespace llfi diff --git a/llvm_passes/core/RegLocBasedFIRegSelector.h b/llvm_passes/core/RegLocBasedFIRegSelector.h index fd75d3e0..a5c18577 100644 --- a/llvm_passes/core/RegLocBasedFIRegSelector.h +++ b/llvm_passes/core/RegLocBasedFIRegSelector.h @@ -4,18 +4,17 @@ #include "Controller.h" #include "FIRegSelector.h" namespace llfi { -class RegLocBasedFIRegSelector: public HardwareFIRegSelector { - public: - RegLocBasedFIRegSelector(FIRegLoc filoc): firegloc(filoc) {} +class RegLocBasedFIRegSelector : public HardwareFIRegSelector { +public: + RegLocBasedFIRegSelector(FIRegLoc filoc) : firegloc(filoc) {} - private: - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst); - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst, int pos); - private: +private: + bool isRegofInstFITarget(Value *reg, Instruction *inst) override; + bool isRegofInstFITarget(Value *reg, Instruction *inst, int pos) override; + +private: FIRegLoc firegloc; }; -} - - +} // namespace llfi #endif diff --git a/llvm_passes/core/Utils.cpp b/llvm_passes/core/Utils.cpp index 74b6b9ce..cea51a84 100644 --- a/llvm_passes/core/Utils.cpp +++ b/llvm_passes/core/Utils.cpp @@ -4,12 +4,11 @@ namespace llfi { std::string demangleFuncName(std::string func) { - std::string ret = func; + std::string ret = func; // Check for name mangling. C++ functions will always start with _Z // Demangled form is processed to remove type information. - if(func.length() >= 2 && (func[0] == '_' && func[1] == 'Z')) { - int stat; - char *test = itaniumDemangle(func.c_str(), NULL, NULL, &stat); + if (func.length() >= 2 && (func[0] == '_' && func[1] == 'Z')) { + char *test = itaniumDemangle(func); // Check if the demangeled function name is null or not. // Thanks Allesio for bringing this up. @@ -22,7 +21,7 @@ std::string demangleFuncName(std::string func) { // Templated functions will have type information first, so skip to the // first space. size_t startpos = demangled.find(" "); - if(startpos < endpos) { + if (startpos < endpos) { // skip until after the space ++startpos; // also modify endpos to the first '<' to remove template info @@ -32,7 +31,7 @@ std::string demangleFuncName(std::string func) { startpos = 0; } - ret = demangled.substr(startpos,endpos); + ret = demangled.substr(startpos, endpos); } return ret; @@ -56,25 +55,26 @@ Instruction *getTermInstofFunction(Function *func) { assert(isa(ret) || isa(ret) || isa(ret) && - "Last instruction is not return or resume or exit() instruction"); + "Last instruction is not return or resume or exit() instruction"); return ret; } -void getAllTermInstofFunction(Function *func, std::set &exitinsts) { - - for (auto i = inst_begin(func); i != inst_end(func); i++) { - Instruction* ret = &*i; +void getAllTermInstofFunction(Function *func, + std::set &exitinsts) { + + for (auto i = inst_begin(func); i != inst_end(func); i++) { + Instruction *ret = &*i; if (isa(ret) || isa(ret) || - isa(ret)) - exitinsts.insert(ret); + isa(ret)) + exitinsts.insert(ret); } } -void getProgramExitInsts(Module &M, std::set &exitinsts) { +void getProgramExitInsts(Module &M, std::set &exitinsts) { for (Module::iterator m_it = M.begin(); m_it != M.end(); ++m_it) { if (!m_it->isDeclaration()) { - //m_it is a function + // m_it is a function for (inst_iterator f_it = inst_begin(&*m_it); f_it != inst_end(&*m_it); ++f_it) { Instruction *inst = &(*f_it); @@ -89,7 +89,7 @@ void getProgramExitInsts(Module &M, std::set &exitinsts) { } } - Function* mainfunc = M.getFunction("main"); + Function *mainfunc = M.getFunction("main"); getAllTermInstofFunction(mainfunc, exitinsts); } @@ -100,28 +100,29 @@ Instruction *getInsertPtrforRegsofInst(Value *reg, Instruction *inst) { // inject into destination reg, insert after inst if (inst->isTerminator()) { errs() << "ERROR: LLFI not able to inject into destination register of " - << *inst << ", change isRegofInstInjectable() to fix it\n"; + << *inst << ", change isRegofInstInjectable() to fix it\n"; exit(2); } else { BasicBlock::iterator bb_it(inst); - while (isa(++bb_it)) ; + while (isa(++bb_it)) + ; return &*bb_it; } } else { // Assume the reg is the src of inst, insert before inst if (isa(inst)) { - errs() << "ERROR: LLFI not able to inject into source register of " << - *inst << ", change isRegofInstInjectable to fix it\n"; + errs() << "ERROR: LLFI not able to inject into source register of " + << *inst << ", change isRegofInstInjectable to fix it\n"; exit(2); } return inst; } } -Instruction* changeInsertPtrIfInjectFaultInst(Instruction *inst) { +Instruction *changeInsertPtrIfInjectFaultInst(Instruction *inst) { MDNode *mdnode = inst->getMetadata("llfi_injectfault"); if (mdnode) { - if (((MDString *)mdnode->getOperand(0).get())->getString() == "after") { + if (cast(mdnode->getOperand(0).get())->getString() == "after") { return inst->getNextNonDebugInstruction(); } else { return inst; @@ -135,7 +136,7 @@ void setInjectFaultInst(Value *reg, Instruction *inst, Instruction *ficall) { Function *func = inst->getParent()->getParent(); LLVMContext &context = func->getContext(); - MDString *s; + MDString *s = nullptr; if (reg == inst) { s = MDString::get(context, "after"); } else { @@ -149,20 +150,19 @@ void setInjectFaultInst(Value *reg, Instruction *inst, Instruction *ficall) { long getLLFIIndexofInst(Instruction *inst) { MDNode *mdnode = inst->getMetadata("llfi_index"); if (mdnode) { - Constant *cns = - dyn_cast(mdnode->getOperand(0))->getValue(); - ConstantInt *cns_index = dyn_cast(cns); + Constant *cns = cast(mdnode->getOperand(0))->getValue(); + ConstantInt *cns_index = cast(cns); return cns_index->getSExtValue(); } else { errs() << "ERROR: LLFI indices for instructions are required for the pass, " - << "please run genllfiindexpass first\n"; + << "please run genllfiindexpass first\n"; exit(3); } } static long fi_index = 1; void setLLFIIndexofInst(Instruction *inst) { - assert (fi_index >= 0 && "static instruction number exceeds index max"); + assert(fi_index >= 0 && "static instruction number exceeds index max"); Function *func = inst->getParent()->getParent(); LLVMContext &context = func->getContext(); std::vector llfiindex(1); @@ -173,48 +173,36 @@ void setLLFIIndexofInst(Instruction *inst) { inst->setMetadata("llfi_index", mdnode); } -void genFullNameOpcodeMap( - std::map &opcodenamemap) { -#define HANDLE_INST(N, OPC, CLASS) \ +void genFullNameOpcodeMap(std::map &opcodenamemap) { +#define HANDLE_INST(N, OPC, CLASS) \ opcodenamemap[std::string(Instruction::getOpcodeName(N))] = N; #include "llvm/IR/Instruction.def" } -//Returns true if the function is indexed by llfi +// Returns true if the function is indexed by llfi //(and therefore we should perform trace/fault injects on it) bool isLLFIIndexedInst(Instruction *inst) { MDNode *mdnode = inst->getMetadata("llfi_index"); if (mdnode) { return true; } else { - return false; - } + return false; + } } -//======== Add opcode_str QINING @SEP 13th======== -GlobalVariable* findOrCreateGlobalNameString(Module &M, std::string name) -{ - LLVMContext& context = M.getContext(); - std::string str_suffix = std::string("_namestr"); - GlobalVariable* nameStr = M.getGlobalVariable(name+str_suffix, true); - if(nameStr != NULL) - { - //found - return nameStr; - } - //not found - //dbgs()<<"\t\t\tNeed new name_str: "<getType(), true, GlobalVariable::InternalLinkage, name_c, gv_nameStr.c_str()); - //add to global list - M.getGlobalList().push_back(nameStr); - return nameStr; +GlobalVariable *findOrCreateGlobalNameString(Module &M, std::string name) { + LLVMContext &context = M.getContext(); + std::string str_suffix = std::string("_namestr"); + GlobalVariable *nameStr = M.getGlobalVariable(name + str_suffix, true); + if (nameStr != nullptr) { + return nameStr; + } + std::string gv_nameStr = name + str_suffix; + Constant *name_c = ConstantDataArray::getString(context, name); + nameStr = new GlobalVariable(M, name_c->getType(), true, + GlobalVariable::InternalLinkage, name_c, + gv_nameStr.c_str()); + return nameStr; } -//================================================ -} +} // namespace llfi diff --git a/llvm_passes/core/Utils.h b/llvm_passes/core/Utils.h index b421c0fe..0da6855d 100644 --- a/llvm_passes/core/Utils.h +++ b/llvm_passes/core/Utils.h @@ -2,26 +2,24 @@ #define LLFI_UTILS_H #include "llvm/Demangle/Demangle.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" -#include "llvm/IR/Value.h" -#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/Constants.h" - -#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" // For name demangling #include - #include #include -#include #include +#include using namespace llvm; namespace llfi { @@ -37,7 +35,7 @@ Instruction *getTermInstofFunction(Function *func); // return instumentation code insertion point for fi in reg of inst Instruction *getInsertPtrforRegsofInst(Value *reg, Instruction *inst); -void getProgramExitInsts(Module &M, std::set &exitinsts); +void getProgramExitInsts(Module &M, std::set &exitinsts); // get or set the LLFI index of the specified instruction. use metadata long getLLFIIndexofInst(Instruction *inst); @@ -46,19 +44,18 @@ void setLLFIIndexofInst(Instruction *inst); // get the map of opcode name and their opcode void genFullNameOpcodeMap(std::map &opcodenamemap); -//Check metadata to see if instruction was generated/inserted by LLFI +// Check metadata to see if instruction was generated/inserted by LLFI bool isLLFIIndexedInst(Instruction *inst); // sets the metadata on the injectFault call void setInjectFaultInst(Value *reg, Instruction *inst, Instruction *ficall); -// checks if the instruction is a call to llfi's 'injectFault*', if it is, return -// the next instruction iff injectFault occurs AFTER the targeted instruction -Instruction* changeInsertPtrIfInjectFaultInst(Instruction *inst); +// checks if the instruction is a call to llfi's 'injectFault*', if it is, +// return the next instruction iff injectFault occurs AFTER the targeted +// instruction +Instruction *changeInsertPtrIfInjectFaultInst(Instruction *inst); -//======== Add opcode_str QINING @SEP 13th======== -GlobalVariable* findOrCreateGlobalNameString(Module &M, std::string name); -//================================================ -} +GlobalVariable *findOrCreateGlobalNameString(Module &M, std::string name); +} // namespace llfi #endif diff --git a/llvm_passes/hardware_failures/FuncNameFIInstSelector.cpp b/llvm_passes/hardware_failures/FuncNameFIInstSelector.cpp index 34d68183..445bf3a4 100644 --- a/llvm_passes/hardware_failures/FuncNameFIInstSelector.cpp +++ b/llvm_passes/hardware_failures/FuncNameFIInstSelector.cpp @@ -1,9 +1,10 @@ -#include "llvm/IR/Instructions.h" - #include "FuncNameFIInstSelector.h" + #include "Utils.h" +#include "llvm/IR/Instructions.h" + namespace llfi { bool FuncNameFIInstSelector::isInstFITarget(Instruction *inst) { @@ -16,4 +17,4 @@ bool FuncNameFIInstSelector::isInstFITarget(Instruction *inst) { return false; } -} +} // namespace llfi diff --git a/llvm_passes/hardware_failures/FuncNameFIInstSelector.h b/llvm_passes/hardware_failures/FuncNameFIInstSelector.h index 61494013..7b734f5a 100644 --- a/llvm_passes/hardware_failures/FuncNameFIInstSelector.h +++ b/llvm_passes/hardware_failures/FuncNameFIInstSelector.h @@ -1,42 +1,38 @@ #ifndef FUNC_NAME_FI_INST_SELECTOR_H #define FUNC_NAME_FI_INST_SELECTOR_H +#include "FIInstSelector.h" + #include #include -#include "FIInstSelector.h" - using namespace llvm; namespace llfi { -class FuncNameFIInstSelector: public HardwareFIInstSelector { - public: +class FuncNameFIInstSelector : public HardwareFIInstSelector { +public: FuncNameFIInstSelector(std::set *funclist) { this->funclist = funclist; } - FuncNameFIInstSelector() { - delete funclist; - } - virtual void getCompileTimeInfo(std::map& info){ + FuncNameFIInstSelector() { delete funclist; } + void getCompileTimeInfo(std::map &info) override { info["failure_class"] = "HardwareFault"; info["failure_mode"] = "SpecifiedFunctions"; - for(std::set::iterator SI = funclist->begin(); - SI != funclist->end(); SI++){ + for (std::set::iterator SI = funclist->begin(); + SI != funclist->end(); SI++) { info["targets"] += *SI + "()/"; } - //remove the '/' at the end - info["targets"] = info["targets"].substr(0, info["targets"].length()-1); + // remove the '/' at the end + info["targets"] = info["targets"].substr(0, info["targets"].length() - 1); info["injector"] = ""; } - private: - virtual bool isInstFITarget(Instruction* inst); +private: + bool isInstFITarget(Instruction *inst) override; - private: +private: std::set *funclist; }; -} - - +} // namespace llfi #endif diff --git a/llvm_passes/hardware_failures/InstTypeFIInstSelector.cpp b/llvm_passes/hardware_failures/InstTypeFIInstSelector.cpp index e15bc422..034e48f6 100644 --- a/llvm_passes/hardware_failures/InstTypeFIInstSelector.cpp +++ b/llvm_passes/hardware_failures/InstTypeFIInstSelector.cpp @@ -1,7 +1,6 @@ -#include "llvm/IR/Instructions.h" - #include "InstTypeFIInstSelector.h" +#include "llvm/IR/Instructions.h" namespace llfi { bool InstTypeFIInstSelector::isInstFITarget(Instruction *inst) { @@ -12,4 +11,4 @@ bool InstTypeFIInstSelector::isInstFITarget(Instruction *inst) { return false; } -} +} // namespace llfi diff --git a/llvm_passes/hardware_failures/InstTypeFIInstSelector.h b/llvm_passes/hardware_failures/InstTypeFIInstSelector.h index 33358784..7aa5ae07 100644 --- a/llvm_passes/hardware_failures/InstTypeFIInstSelector.h +++ b/llvm_passes/hardware_failures/InstTypeFIInstSelector.h @@ -1,35 +1,32 @@ #ifndef INST_TYPE_FI_INST_SELECTOR_H #define INST_TYPE_FI_INST_SELECTOR_H -#include - #include "FIInstSelector.h" +#include + using namespace llvm; namespace llfi { -class InstTypeFIInstSelector: public HardwareFIInstSelector { - public: +class InstTypeFIInstSelector : public HardwareFIInstSelector { +public: InstTypeFIInstSelector(std::set *opcodelist) { this->opcodelist = opcodelist; } - ~InstTypeFIInstSelector() { - delete opcodelist; - } - virtual void getCompileTimeInfo(std::map& info){ + ~InstTypeFIInstSelector() override { delete opcodelist; } + void getCompileTimeInfo(std::map &info) override { info["failure_class"] = "HardwareFault"; info["failure_mode"] = "SpecifiedInstructionTypes"; info["targets"] = ""; info["injector"] = ""; } - private: - virtual bool isInstFITarget(Instruction* inst); - private: +private: + bool isInstFITarget(Instruction *inst) override; + +private: std::set *opcodelist; }; -} - - +} // namespace llfi #endif diff --git a/llvm_passes/hardware_failures/LLFIIndexFIInstSelector.cpp b/llvm_passes/hardware_failures/LLFIIndexFIInstSelector.cpp index b25b1740..f1190a2b 100644 --- a/llvm_passes/hardware_failures/LLFIIndexFIInstSelector.cpp +++ b/llvm_passes/hardware_failures/LLFIIndexFIInstSelector.cpp @@ -1,32 +1,34 @@ -#include "llvm/IR/Instructions.h" -#include "llvm/Support/CommandLine.h" - -#include "FIInstSelector.h" #include "FICustomSelectorManager.h" +#include "FIInstSelector.h" #include "Utils.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/CommandLine.h" + using namespace llvm; namespace llfi { -static cl::list< std::string > injecttoindex("injecttoindex", - cl::desc("Inject into the specified LLFI index"), - cl::ZeroOrMore); +static cl::list + injecttoindex("injecttoindex", + cl::desc("Inject into the specified LLFI index"), + cl::ZeroOrMore); /** * LLFI Index instruction selector selects instruction of certain indices */ -class LLFIIndexFIInstSelector: public HardwareFIInstSelector { - private: - virtual bool isInstFITarget(Instruction *inst) { +class LLFIIndexFIInstSelector : public HardwareFIInstSelector { +private: + bool isInstFITarget(Instruction *inst) override { long llfiindex = getLLFIIndexofInst(inst); for (unsigned i = 0; i != injecttoindex.size(); ++i) if (atol(injecttoindex[i].c_str()) == llfiindex) return true; return false; } - public: - virtual void getCompileTimeInfo(std::map& info){ + +public: + void getCompileTimeInfo(std::map &info) override { info["failure_class"] = "HardwareFault"; info["failure_mode"] = "SpecifiedLLFIIndex"; info["targets"] = ""; @@ -35,4 +37,4 @@ class LLFIIndexFIInstSelector: public HardwareFIInstSelector { }; static RegisterFIInstSelector X("llfiindex", new LLFIIndexFIInstSelector()); -} +} // namespace llfi diff --git a/llvm_passes/instruction_duplication/InstructionDuplication.cpp b/llvm_passes/instruction_duplication/InstructionDuplication.cpp index 438350de..b605b05a 100644 --- a/llvm_passes/instruction_duplication/InstructionDuplication.cpp +++ b/llvm_passes/instruction_duplication/InstructionDuplication.cpp @@ -1,518 +1,553 @@ -#define DEBUG_TYPE "InstructionDuplicationPass" - #include "FICustomSelectorManager.h" -#include "Utils.h" -#include "FIInstSelectorManager.h" #include "FIInstSelector.h" -#include "InstTypeFIInstSelector.h" -#include "FuncNameFIInstSelector.h" +#include "FIInstSelectorManager.h" #include "FIRegSelector.h" +#include "InstTypeFIInstSelector.h" #include "RegLocBasedFIRegSelector.h" +#include "Utils.h" -#include "llvm/Pass.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" #include "llvm/Support/CommandLine.h" -#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include +#include +#include #include #include -#include #include -#include -#include #include -#include +#include + +#include "FuncNameFIInstSelector.h" using namespace llvm; using namespace std; namespace SID { - static cl::opt< string > llfiIndex("llfiIndex", cl::desc(" \ +static cl::opt llfiIndex("llfiIndex", cl::desc(" \ llfiIndex of the arithmetic instruction to duplicate. Default:all"), - cl::init("all")); + cl::init("all")); - static cl::opt< string > layerName("operatorName", cl::desc("Name of \ +static cl::opt layerName("operatorName", cl::desc("Name of \ operator(s) to duplicate. Semi-colon seperated values.\ - Example: conv;relu;matmul;maxpool;all"), cl::init("all")); + Example: conv;relu;matmul;maxpool;all"), + cl::init("all")); - static cl::opt< bool > enableChainDuplication("enableChainDuplication", - cl::desc("Boolean value to indicate whether to do arithmetic chain \ - duplication or not. Default: False"), cl::init(false)); +static cl::opt enableChainDuplication( + "enableChainDuplication", + cl::desc("Boolean value to indicate whether to do arithmetic chain \ + duplication or not. Default: False"), + cl::init(false)); - // Return an array our of string of comma-seperated values. - vector getCommaSeperateVals(string inp) { +// Return an array our of string of comma-seperated values. +vector getCommaSeperateVals(string inp) { - string s = inp; - string delimiter = ";"; - vector retval; - size_t pos = 0; + string s = inp; + string delimiter = ";"; + vector retval; + size_t pos = 0; - string token; - while ((pos = s.find(delimiter)) != string::npos) { - token = s.substr(0, pos); - retval.push_back(token); - s.erase(0, pos + delimiter.length()); - } + string token; + while ((pos = s.find(delimiter)) != string::npos) { + token = s.substr(0, pos); + retval.push_back(token); + s.erase(0, pos + delimiter.length()); + } - if ((pos = s.find(delimiter)) == string::npos) { - retval.push_back(s); - } + retval.push_back(s); - return retval; - } + return retval; +} - // Get unique Id corresponding to the ONNX operator. - static int64_t getOperatorNumber(string name) { +// Get unique Id corresponding to the ONNX operator. +static int64_t getOperatorNumber(string name) { - char opname[100]; - transform(name.begin(), name.end(), name.begin(), - [](unsigned char c){ return tolower(c); }); + char opname[100]; + transform(name.begin(), name.end(), name.begin(), + [](unsigned char c) { return tolower(c); }); - strcpy(opname, name.c_str()); + strncpy(opname, name.c_str(), sizeof(opname) - 1); + opname[sizeof(opname) - 1] = '\0'; - // ONNX assigns unique IDs to each tensor operator. - map ONNXOperatorId = { - {"conv", 1986948931}, - {"relu", 1970038098}, - {"maxpool", 30521821366870349}, - {"matmul", 119251066446157}, - {"add", 6579265}, - {"avgpool", 30521821365761601}, - {"softmax", 33884119937478483} - }; + // ONNX assigns unique IDs to each tensor operator. + map ONNXOperatorId = {{"conv", 1986948931}, + {"relu", 1970038098}, + {"maxpool", 30521821366870349}, + {"matmul", 119251066446157}, + {"add", 6579265}, + {"avgpool", 30521821365761601}, + {"softmax", 33884119937478483}}; - if (ONNXOperatorId.find(opname) == ONNXOperatorId.end()) - return -1; + if (ONNXOperatorId.find(opname) == ONNXOperatorId.end()) + return -1; - return ONNXOperatorId[opname]; - } + return ONNXOperatorId[opname]; +} - // Add Metadata to LLVM instructions; Only for debugging purposes! - void addMetadata(Instruction *ins, char *st = NULL){ - LLVMContext& C = ins->getContext(); - MDNode* N = MDNode::get(C, MDString::get(C, (!st) ? "t" : st)); +// Add Metadata to LLVM instructions; Only for debugging purposes! +void addMetadata(Instruction *ins, const char *st = nullptr) { + LLVMContext &C = ins->getContext(); + MDNode *N = MDNode::get(C, MDString::get(C, (!st) ? "t" : st)); - char finalMD[1000] = "Debug."; - strcat(finalMD, st); - ins->setMetadata(finalMD, N); - } + char finalMD[1000] = "Debug."; + strncat(finalMD, st ? st : "t", sizeof(finalMD) - strlen(finalMD) - 1); + ins->setMetadata(finalMD, N); +} - void printBB(BasicBlock* bb){ +void printBB(BasicBlock *bb) { - errs()<<"------- Printing BB -------------\n"; - for (BasicBlock::const_iterator i = bb->begin(); i != bb->end(); ++i) { + errs() << "------- Printing BB -------------\n"; + for (BasicBlock::const_iterator i = bb->begin(); i != bb->end(); ++i) { - Instruction* inst = const_cast(&*i); - errs()<<*inst<<"\n"; - } - } + Instruction *inst = const_cast(&*i); + errs() << *inst << "\n"; + } +} - void printFunction(Function& F){ - errs()<<"------- Printing Function -------------\n"; +void printFunction(Function &F) { + errs() << "------- Printing Function -------------\n"; - for (BasicBlock& bb : F){ - for (BasicBlock::const_iterator i = bb.begin(); i != bb.end(); ++i) { + for (BasicBlock &bb : F) { + for (BasicBlock::const_iterator i = bb.begin(); i != bb.end(); ++i) { - Instruction* inst = const_cast(&*i); - errs()<<*inst<<"\n"; - } - } + Instruction *inst = const_cast(&*i); + errs() << *inst << "\n"; } + } } +} // namespace SID using namespace SID; -namespace llfi{ - - class InstructionDuplicationPass: public FunctionPass - { - - private: - bool isInitialized; - vector operatorValues; - bool injectInAllOperators; - vector llfiIndexes; - bool injectInAllIndexes; - bool isChainDuplication; - - // Initializes Layer name and granularity of instruction duplication. - void initializeGranularityAndLayerName(string llfiIndex, - string layerName, bool isChainDupl) { - - // Parse operators. - vector OperatorNames = getCommaSeperateVals(layerName); - - for (string name : OperatorNames) { - - if (name.find("all") != string::npos) { - injectInAllOperators = true; - break; - } - - int64_t temp = getOperatorNumber(name); - - if (temp == -1) { - cerr<<"Invalid operator name: "< operatorValues; + bool injectInAllOperators; + vector llfiIndexes; + bool injectInAllIndexes; + bool isChainDuplication; - // Parse LLFIIndexes for FI. - vector LLFIIndexes = getCommaSeperateVals(llfiIndex); + // Initializes Layer name and granularity of instruction duplication. + void initializeGranularityAndLayerName(string llfiIndex, string layerName, + bool isChainDupl) { - for (string index : LLFIIndexes) { + // Parse operators. + vector OperatorNames = getCommaSeperateVals(layerName); - if (index.find("all") != string::npos) { - injectInAllIndexes = true; - break; - } + for (const string &name : OperatorNames) { - int64_t indexLong = atol(index.c_str()); + if (name.find("all") != string::npos) { + injectInAllOperators = true; + break; + } - assert(indexLong > 0 && "Invalid LLFIIndex"); + int64_t temp = getOperatorNumber(name); - llfiIndexes.push_back(indexLong); - } - } + if (temp == -1) { + cerr << "Invalid operator name: " << name << "\n"; + assert(false && "Invalid Operator Name"); + } - public: + operatorValues.push_back(temp); + } - static char ID; + // Parse enableChainDuplication boolean + isChainDuplication = isChainDupl; - InstructionDuplicationPass():FunctionPass(ID) - { - isInitialized = false; - injectInAllIndexes = false; - injectInAllOperators = false; - } + // Parse LLFIIndexes for FI. + vector LLFIIndexes = getCommaSeperateVals(llfiIndex); - // Duplicate a single arithmetic instruction - void duplicateInstruction(Instruction* inst) { + for (const string &index : LLFIIndexes) { - // Clone the instruction and reassign the operands. - Instruction* duplicatedInst = inst->clone(); - for(unsigned int i = 0; i < duplicatedInst->getNumOperands(); i++){ - duplicatedInst->setOperand(i, inst->getOperand(i)); - } + if (index.find("all") != string::npos) { + injectInAllIndexes = true; + break; + } - // Copy metadata - MDNode *mdnode = inst->getMetadata("llfi_index"); - inst->setMetadata("llfi_index", NULL); - duplicatedInst->setMetadata("llfi_index", mdnode); - addMetadata(duplicatedInst, "Duplicated_Instruction"); + int64_t indexLong = atol(index.c_str()); - // Insert the duplicate instruction - duplicatedInst->insertBefore(inst->getNextNode()); + assert(indexLong > 0 && "Invalid LLFIIndex"); - IRBuilder<> IRB(inst->getParent()); - IRB.SetInsertPoint(duplicatedInst->getNextNode()); + llfiIndexes.push_back(indexLong); + } + } - auto Fn = inst->getFunction()->getParent()->getOrInsertFunction( - "compareFloatValues", Type::getFloatTy(inst->getContext()), - Type::getFloatTy(inst->getContext()), - Type::getFloatTy(inst->getContext())); +public: + static char ID; - Value *funret = IRB.CreateCall(Fn, {inst, duplicatedInst}); + InstructionDuplicationPass() : FunctionPass(ID) { + isInitialized = false; + isChainDuplication = false; + injectInAllIndexes = false; + injectInAllOperators = false; + } - auto myIf = [&](Use &operand) { - if (isa(operand.getUser())) - return false; - return true; - }; + // Duplicate a single arithmetic instruction + void duplicateInstruction(Instruction *inst) { - inst->replaceUsesWithIf(funret, myIf); - } + // Clone the instruction and reassign the operands. + Instruction *duplicatedInst = inst->clone(); + for (unsigned int i = 0; i < duplicatedInst->getNumOperands(); i++) { + duplicatedInst->setOperand(i, inst->getOperand(i)); + } - // Duplicate a chain of arithmetic instructions. - void duplicateInstructionChain(vector insVector) - { + // Copy metadata + MDNode *mdnode = inst->getMetadata("llfi_index"); + inst->setMetadata("llfi_index", nullptr); + duplicatedInst->setMetadata("llfi_index", mdnode); + addMetadata(duplicatedInst, "Duplicated_Instruction"); - llvm::ValueToValueMapTy vmap; - vector new_instructions; + // Insert the duplicate instruction + duplicatedInst->insertBefore(inst->getNextNode()->getIterator()); - // Keep track of the last instructions of the instruction chain. - Instruction* lastInst, *lastInstDupl; + IRBuilder<> IRB(inst->getParent()); + IRB.SetInsertPoint(duplicatedInst->getNextNode()); - if (insVector.size() > 0) { - for (auto *inst: insVector) { + auto Fn = inst->getFunction()->getParent()->getOrInsertFunction( + "compareFloatValues", Type::getFloatTy(inst->getContext()), + Type::getFloatTy(inst->getContext()), + Type::getFloatTy(inst->getContext())); - // Clone the instruction - Instruction* duplicatedInst = inst->clone(); + Value *funret = IRB.CreateCall(Fn, {inst, duplicatedInst}); - lastInst = inst; - lastInstDupl = duplicatedInst; + auto myIf = [&](Use &operand) { + if (isa(operand.getUser())) + return false; + return true; + }; - for(unsigned int i = 0; i < duplicatedInst->getNumOperands(); i++){ - duplicatedInst->setOperand(i, inst->getOperand(i)); - } + inst->replaceUsesWithIf(funret, myIf); + } - // Copy metadata - MDNode *mdnode = inst->getMetadata("llfi_index"); - inst->setMetadata("llfi_index", NULL); - duplicatedInst->setMetadata("llfi_index", mdnode); - addMetadata(duplicatedInst, "Duplicated_Instruction_In_Chain"); + // Duplicate a chain of arithmetic instructions. + void duplicateInstructionChain(vector insVector) { - // Insert the duplicated instruction in the LLVM IR - duplicatedInst->insertAfter(inst); + llvm::ValueToValueMapTy vmap; + vector new_instructions; - new_instructions.push_back(duplicatedInst); - vmap[inst] = duplicatedInst; - } - } + // Keep track of the last instructions of the instruction chain. + Instruction *lastInst = nullptr, *lastInstDupl = nullptr; - for (auto *i : new_instructions) { - llvm::RemapInstruction(i, vmap, RF_NoModuleLevelChanges | - RF_IgnoreMissingLocals); - } + if (!insVector.empty()) { + for (auto *inst : insVector) { - // Insert the compareFloatValue function - IRBuilder<> IRB(lastInst->getParent()); - IRB.SetInsertPoint(lastInstDupl->getNextNode()); + // Clone the instruction + Instruction *duplicatedInst = inst->clone(); - auto Fn = lastInst->getFunction()->getParent()-> - getOrInsertFunction("compareFloatValues", - Type::getFloatTy(lastInst->getContext()), - Type::getFloatTy(lastInst->getContext()), - Type::getFloatTy(lastInst->getContext()) - ); + lastInst = inst; + lastInstDupl = duplicatedInst; - Value *funret = IRB.CreateCall(Fn, {lastInst, lastInstDupl}); - - // Replace all use of the arithmatic instruction with the function - // return value - auto myIf = [&](Use &operand) { - if (isa(operand.getUser())) - return false; - return true; - }; - - lastInst->replaceUsesWithIf(funret, myIf); + for (unsigned int i = 0; i < duplicatedInst->getNumOperands(); i++) { + duplicatedInst->setOperand(i, inst->getOperand(i)); } - bool isArithmeticInstruction(Instruction* inst) - { - // Don't do instruction duplication in FCmp. - if (inst != NULL && (inst->getOpcode() == Instruction::FAdd || - inst->getOpcode() == Instruction::FSub || - inst->getOpcode() == Instruction::FMul || - inst->getOpcode() == Instruction::FDiv)) - return true; - else - return false; - } + // Copy metadata + MDNode *mdnode = inst->getMetadata("llfi_index"); + inst->setMetadata("llfi_index", nullptr); + duplicatedInst->setMetadata("llfi_index", mdnode); + addMetadata(duplicatedInst, "Duplicated_Instruction_In_Chain"); - bool checkInstructionIndex(Instruction* inst) { + // Insert the duplicated instruction in the LLVM IR + duplicatedInst->insertAfter(inst); - if (injectInAllIndexes) return true; + new_instructions.push_back(duplicatedInst); + vmap[inst] = duplicatedInst; + } + } - MDNode *mdnode = inst->getMetadata("llfi_index"); - long vindex = 0; - if (mdnode) { - ConstantInt *cns_index = mdconst::dyn_extract(mdnode->getOperand(0)); - vindex = cns_index->getSExtValue(); - } + for (auto *i : new_instructions) { + llvm::RemapInstruction(i, vmap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + } - for(long idx : llfiIndexes) { + // Insert the compareFloatValue function + assert(lastInst != nullptr && lastInstDupl != nullptr); + IRBuilder<> IRB(lastInst->getParent()); + IRB.SetInsertPoint(lastInstDupl->getNextNode()); - if (idx == vindex) - return true; - } + auto Fn = lastInst->getFunction()->getParent()->getOrInsertFunction( + "compareFloatValues", Type::getFloatTy(lastInst->getContext()), + Type::getFloatTy(lastInst->getContext()), + Type::getFloatTy(lastInst->getContext())); - return false; - } + Value *funret = IRB.CreateCall(Fn, {lastInst, lastInstDupl}); - bool doArithmeticInstructionDuplication(Function& F) - { - vector arithInst; - bool isCustomTensorOperator = false; + // Replace all use of the arithmatic instruction with the function + // return value + auto myIf = [&](Use &operand) { + if (isa(operand.getUser())) + return false; + return true; + }; - // Find all the floating-point arithmetic instructions in this function - for (BasicBlock &bb : F) { - for (BasicBlock::const_iterator i = bb.begin(); i != bb.end(); - ++i) { - Instruction* inst = const_cast(&*i); + lastInst->replaceUsesWithIf(funret, myIf); + } + + bool isArithmeticInstruction(Instruction *inst) { + // Don't do instruction duplication in FCmp. + if (inst != nullptr && (inst->getOpcode() == Instruction::FAdd || + inst->getOpcode() == Instruction::FSub || + inst->getOpcode() == Instruction::FMul || + inst->getOpcode() == Instruction::FDiv)) + return true; + else + return false; + } + + bool checkInstructionIndex(Instruction *inst) { + + if (injectInAllIndexes) + return true; + + MDNode *mdnode = inst->getMetadata("llfi_index"); + long vindex = 0; + if (mdnode) { + ConstantInt *cns_index = + mdconst::dyn_extract(mdnode->getOperand(0)); + vindex = cns_index->getSExtValue(); + } - if (inst->getOpcode() == Instruction::Call){ - CallInst* callinst = dyn_cast(inst); + for (long idx : llfiIndexes) { - // If this is OMInstrument function? - if ((callinst->getCalledFunction())->getName() == - "OMInstrumentPoint") { + if (idx == vindex) + return true; + } + + return false; + } - Value* arg1 = callinst->getArgOperand(0); - Value* arg2 = callinst->getArgOperand(1); + bool doArithmeticInstructionDuplication(Function &F) { + vector arithInst; + bool isCustomTensorOperator = false; - ConstantInt* ci1 = dyn_cast(arg1); - ConstantInt* ci2 = dyn_cast(arg2); + // Find all the floating-point arithmetic instructions in this function + for (BasicBlock &bb : F) { + for (BasicBlock::const_iterator i = bb.begin(); i != bb.end(); ++i) { + Instruction *inst = const_cast(&*i); - int64_t argValue1 = ci1->getSExtValue(); - int64_t argValue2 = ci2->getSExtValue(); + if (inst->getOpcode() == Instruction::Call) { + CallInst *callinst = cast(inst); - if (argValue2 == 2 && shouldInjectFault(argValue1)) { + // If this is OMInstrument function? + if (callinst->getCalledFunction() && + callinst->getCalledFunction()->getName() == "OMInstrumentPoint") { - // Inject fault! - isCustomTensorOperator = true; - } + Value *arg1 = callinst->getArgOperand(0); + Value *arg2 = callinst->getArgOperand(1); - if (argValue2 == 1 && shouldInjectFault(argValue1)) { + ConstantInt *ci1 = dyn_cast(arg1); + ConstantInt *ci2 = dyn_cast(arg2); + if (!ci1 || !ci2) + continue; - // Set this to false after the operator ends. - isCustomTensorOperator = false; - } - } - } + int64_t argValue1 = ci1->getSExtValue(); + int64_t argValue2 = ci2->getSExtValue(); - if (isCustomTensorOperator && isArithmeticInstruction(inst) && checkInstructionIndex(inst)) { + if (argValue2 == 2 && shouldInjectFault(argValue1)) { - arithInst.push_back(inst); - } - } + // Inject fault! + isCustomTensorOperator = true; } - // Then duplicate the arithmetic instructions. - for (auto ins : arithInst){ + if (argValue2 == 1 && shouldInjectFault(argValue1)) { - duplicateInstruction(ins); + // Set this to false after the operator ends. + isCustomTensorOperator = false; } - - return true; + } } - bool doArithmeticChainDuplication(Function& F) - { - vector> arithInst; - bool isCustomTensorOperator = false; + if (isCustomTensorOperator && isArithmeticInstruction(inst) && + checkInstructionIndex(inst)) { - // Find all the floating-point arithmetic instructions in this function - for (BasicBlock &bb : F) { - for (BasicBlock::const_iterator i = bb.begin(); i != bb.end(); - ++i) { - Instruction* inst = const_cast(&*i); - - if (inst->getOpcode() == Instruction::Call){ - CallInst* callinst = dyn_cast(inst); + arithInst.push_back(inst); + } + } + } - // If this is OMInstrument function? - if ((callinst->getCalledFunction())->getName() == - "OMInstrumentPoint") { + // Then duplicate the arithmetic instructions. + for (auto ins : arithInst) { - Value* arg1 = callinst->getArgOperand(0); - Value* arg2 = callinst->getArgOperand(1); + duplicateInstruction(ins); + } - ConstantInt* ci1 = dyn_cast(arg1); - ConstantInt* ci2 = dyn_cast(arg2); + return true; + } - int64_t argValue1 = ci1->getSExtValue(); - int64_t argValue2 = ci2->getSExtValue(); + bool doArithmeticChainDuplication(Function &F) { + vector> arithInst; + bool isCustomTensorOperator = false; - if (argValue2 == 2 && shouldInjectFault(argValue1)) { + // Find all the floating-point arithmetic instructions in this function + for (BasicBlock &bb : F) { + for (BasicBlock::const_iterator i = bb.begin(); i != bb.end(); ++i) { + Instruction *inst = const_cast(&*i); - // Inject fault! - isCustomTensorOperator = true; - } + if (inst->getOpcode() == Instruction::Call) { + CallInst *callinst = cast(inst); - if (argValue2 == 1 && shouldInjectFault(argValue1)) { + // If this is OMInstrument function? + if (callinst->getCalledFunction() && + callinst->getCalledFunction()->getName() == "OMInstrumentPoint") { - // Set this to false after the operator ends. - isCustomTensorOperator = false; - } - } - } + Value *arg1 = callinst->getArgOperand(0); + Value *arg2 = callinst->getArgOperand(1); - if (isCustomTensorOperator && isArithmeticInstruction(inst) && checkInstructionIndex(inst)){ + ConstantInt *ci1 = dyn_cast(arg1); + ConstantInt *ci2 = dyn_cast(arg2); + if (!ci1 || !ci2) + continue; - vector temp; - temp.push_back(inst); + int64_t argValue1 = ci1->getSExtValue(); + int64_t argValue2 = ci2->getSExtValue(); - // Detects chain of arithmetic instructions - Instruction* currInst = inst; - while (true) { - if (isArithmeticInstruction(currInst->getNextNonDebugInstruction())&& - checkInstructionIndex(currInst->getNextNonDebugInstruction())) { - ++i; - currInst = currInst->getNextNonDebugInstruction(); - temp.push_back(currInst); - } - else - break; - } + if (argValue2 == 2 && shouldInjectFault(argValue1)) { - arithInst.push_back(temp); - } - } + // Inject fault! + isCustomTensorOperator = true; } - // Then duplicate the arithmetic instructions. - for (auto insVector : arithInst){ + if (argValue2 == 1 && shouldInjectFault(argValue1)) { - if (insVector.size() == 1) - duplicateInstruction(insVector[0]); - else if (insVector.size() > 1){ - duplicateInstructionChain(insVector); - } - else - assert(false && "Size of insVector can't be zero!"); + // Set this to false after the operator ends. + isCustomTensorOperator = false; } + } + } - return true; + if (isCustomTensorOperator && isArithmeticInstruction(inst) && + checkInstructionIndex(inst)) { + + vector temp; + temp.push_back(inst); + + // Detects chain of arithmetic instructions + Instruction *currInst = inst; + while (true) { + if (isArithmeticInstruction( + currInst->getNextNonDebugInstruction()) && + checkInstructionIndex(currInst->getNextNonDebugInstruction())) { + ++i; + currInst = currInst->getNextNonDebugInstruction(); + temp.push_back(currInst); + } else + break; + } + + arithInst.push_back(temp); } + } + } - bool shouldInjectFault(int64_t number) { + // Then duplicate the arithmetic instructions. + for (const auto &insVector : arithInst) { - if (injectInAllOperators) return true; + if (insVector.size() == 1) + duplicateInstruction(insVector[0]); + else if (insVector.size() > 1) { + duplicateInstructionChain(insVector); + } else + assert(false && "Size of insVector can't be zero!"); + } - // If the operator isn't present in the map. - for (auto num : operatorValues) { + return true; + } - if (number == num) - return true; - } + bool shouldInjectFault(int64_t number) { - return false; - } + if (injectInAllOperators) + return true; - bool runOnMainGraph(Function& F) - { - // Parse input options. - if (!isInitialized) { - isInitialized = true; - initializeGranularityAndLayerName(llfiIndex, layerName, enableChainDuplication); - } + // If the operator isn't present in the map. + for (auto num : operatorValues) { - if (isChainDuplication){ - return doArithmeticChainDuplication(F); - } - else { - return doArithmeticInstructionDuplication(F); - } + if (number == num) + return true; + } - return false; - } + return false; + } - virtual bool runOnFunction(Function &F) - { + bool runOnMainGraph(Function &F) { + // Parse input options. + if (!isInitialized) { + isInitialized = true; + initializeGranularityAndLayerName(llfiIndex.getValue(), + layerName.getValue(), + (bool)enableChainDuplication); + } - if (F.getName() == "main_graph") { - return runOnMainGraph(F); - } + if (isChainDuplication) { + return doArithmeticChainDuplication(F); + } else { + return doArithmeticInstructionDuplication(F); + } + } - return false; - } - }; + bool runOnFunction(Function &F) override { - char InstructionDuplicationPass::ID = 0; + if (F.getName() == "main_graph") { + return runOnMainGraph(F); + } - static RegisterPass - X("InstructionDuplicationPass", "Automatic Duplication of ML applications", - false, false); + return false; + } +}; + +char InstructionDuplicationPass::ID = 0; + +static RegisterPass + X("InstructionDuplicationPass", "Automatic Duplication of ML applications", + false, false); + +// New pass manager wrapper — iterates over all functions just like the +// legacy FunctionPass, preserving the existing per-function logic. +struct NewInstructionDuplicationPass + : llvm::PassInfoMixin { + + llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &) { + InstructionDuplicationPass legacy; + bool changed = false; + for (Function &F : M) + changed |= legacy.runOnFunction(F); + return changed ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all(); + } + + static bool isRequired() { return true; } +}; +} // namespace llfi + +//----------------------------------------------------------------------------- +// New PM plugin entry point for SEDPasses.so +//----------------------------------------------------------------------------- +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo +llvmGetPassPluginInfo() { + return {LLVM_PLUGIN_API_VERSION, "SEDPasses", LLVM_VERSION_STRING, + [](llvm::PassBuilder &PB) { + PB.registerPipelineParsingCallback( + [](llvm::StringRef Name, llvm::ModulePassManager &MPM, + llvm::ArrayRef) { + if (Name == "InstructionDuplicationPass") { + MPM.addPass(llfi::NewInstructionDuplicationPass()); + return true; + } + return false; + }); + }}; } diff --git a/llvm_passes/instruction_duplication/README.md b/llvm_passes/instruction_duplication/README.md index f4d425f4..9119469a 100644 --- a/llvm_passes/instruction_duplication/README.md +++ b/llvm_passes/instruction_duplication/README.md @@ -20,16 +20,17 @@ sh compile_shrd_lib.sh ``` # Perform instruction duplication -$LLVM_BUILD_PATH/bin/opt -load ../../../../build/llvm_passes/instruction_duplication/SEDPasses.so \ - --InstructionDuplicationPass -operatorName=all \ - --enableChainDuplication --enable-new-pm=0 -S model.ll -o model_change.ll \ +$LLVM_BUILD_PATH/bin/opt \ + -load-pass-plugin ../../../../build/llvm_passes/instruction_duplication/SEDPasses.so \ + --passes=InstructionDuplicationPass --operatorName=all \ + --enableChainDuplication -S model.ll -o model_change.ll \ > /dev/null # Link the comparision checks llvm-link -o model_change.ll -S model_change.ll SIDHelperFunctions.ll # Inline the comparison checks -$LLVM_BUILD_PATH/bin/opt model_change.ll -always-inline -S -o model.ll +$LLVM_BUILD_PATH/bin/opt --passes=always-inline -S model_change.ll -o model.ll ``` diff --git a/llvm_passes/instruction_duplication/shared_lib/SIDHelperFunctions.cpp b/llvm_passes/instruction_duplication/shared_lib/SIDHelperFunctions.cpp index 38656336..43130128 100644 --- a/llvm_passes/instruction_duplication/shared_lib/SIDHelperFunctions.cpp +++ b/llvm_passes/instruction_duplication/shared_lib/SIDHelperFunctions.cpp @@ -1,46 +1,46 @@ // Representation of a 32-bit float typedef union { - float f; - struct { - unsigned int mantisa : 23; - unsigned int exponent : 8; - unsigned int sign : 1; - } parts; + float f; + struct { + unsigned int mantisa : 23; + unsigned int exponent : 8; + unsigned int sign : 1; + } parts; } float_cast; extern "C" { -float __attribute__((always_inline)) compareFloatValues(float val1, float val2) { - - //int v1 = *reinterpret_cast(&val1); - //int v2 = *reinterpret_cast(&val2); - - /* - if (!(v1 ^ v2)) - return val1; - else { - // Take the absolute and return the min - // Why even take the absolute? Most activation functions don't allow - // passing large negative values. - - uint32_t temp = v1 >> 31; - v1 ^= temp; - v1 += temp & 1; - temp = v2 >> 31; - v2 ^= temp; - v2 += temp & 1; - */ - - - float_cast f1 = {.f = val1}; - float_cast f2 = {.f = val2}; - float_cast retval = {.f = 0.0}; - retval.parts.mantisa = f1.parts.mantisa & f2.parts.mantisa; - retval.parts.exponent = f1.parts.exponent & f2.parts.exponent; - retval.parts.sign = f1.parts.sign & f2.parts.sign; - return retval.f; - -// return (val1 < val2) ? val1 : val2; +float __attribute__((always_inline)) compareFloatValues(float val1, + float val2) { + + // int v1 = *reinterpret_cast(&val1); + // int v2 = *reinterpret_cast(&val2); + + /* + if (!(v1 ^ v2)) + return val1; + else { + // Take the absolute and return the min + // Why even take the absolute? Most activation functions don't allow + // passing large negative values. + + uint32_t temp = v1 >> 31; + v1 ^= temp; + v1 += temp & 1; + temp = v2 >> 31; + v2 ^= temp; + v2 += temp & 1; + */ + + float_cast f1 = {.f = val1}; + float_cast f2 = {.f = val2}; + float_cast retval = {.f = 0.0}; + retval.parts.mantisa = f1.parts.mantisa & f2.parts.mantisa; + retval.parts.exponent = f1.parts.exponent & f2.parts.exponent; + retval.parts.sign = f1.parts.sign & f2.parts.sign; + return retval.f; + + // return (val1 < val2) ? val1 : val2; } } diff --git a/llvm_passes/instruction_duplication/shared_lib/build.sh b/llvm_passes/instruction_duplication/shared_lib/build.sh index e4512b7d..5cf5534d 100644 --- a/llvm_passes/instruction_duplication/shared_lib/build.sh +++ b/llvm_passes/instruction_duplication/shared_lib/build.sh @@ -1 +1,2 @@ -../../../../../llvm-project/build/bin/clang -shared SIDHelperFunctions.cpp -o libSIDHelperFunctions.so +CLANG=${LLVM_GXX_BIN_DIR:+$LLVM_GXX_BIN_DIR/}clang +${CLANG} -shared SIDHelperFunctions.cpp -o libSIDHelperFunctions.so diff --git a/llvm_passes/instruction_duplication/shared_lib/compile_shrd_lib.sh b/llvm_passes/instruction_duplication/shared_lib/compile_shrd_lib.sh index 925be55f..40f4df28 100644 --- a/llvm_passes/instruction_duplication/shared_lib/compile_shrd_lib.sh +++ b/llvm_passes/instruction_duplication/shared_lib/compile_shrd_lib.sh @@ -1,3 +1,4 @@ #!/bin/sh -clang++ -S -fno-inline -fPIC -emit-llvm SIDHelperFunctions.cpp -o SIDHelperFunctions.ll -O3 +CLANGXX=${LLVM_GXX_BIN_DIR:+$LLVM_GXX_BIN_DIR/}clang++ +${CLANGXX} -S -fno-inline -fPIC -emit-llvm SIDHelperFunctions.cpp -o SIDHelperFunctions.ll -O3 diff --git a/llvm_passes/software_failures/_SoftwareFaultRegSelectors.cpp b/llvm_passes/software_failures/_SoftwareFaultRegSelectors.cpp deleted file mode 100644 index a065af92..00000000 --- a/llvm_passes/software_failures/_SoftwareFaultRegSelectors.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "_SoftwareFaultRegSelectors.h" - -using namespace std; -namespace llfi { - bool FuncArgRegSelector::isRegofInstFITarget(Value *reg, Instruction *inst){ - if(isa(inst) == false){ - return false; - }else{ - CallInst* CI = dyn_cast(inst); - if(this->specified_arg == true){ - if(reg == CI->getArgOperand(this->pos_argument)){ - return true; - }else return false; - }else{ - for(int i = 0; igetNumArgOperands(); i++){ - if(reg == CI->getArgOperand(i)) return true; - } - return false; - } - } - } - bool FuncArgRegSelector::isRegofInstFITarget(Value *reg, Instruction *inst, int pos){ - if(specified_arg == true) - return isRegofInstFITarget(reg, inst) && pos == this->pos_argument; - } - - bool FuncDestRegSelector::isRegofInstFITarget(Value *reg, Instruction *inst){ - if(isa(inst) == false){ - return false; - }else{ - if(reg == inst) return true; - else return false; - } - } - - bool RetValRegSelector::isRegofInstFITarget(Value *reg, Instruction *inst){ - if(isa(inst)){ - ReturnInst* RI = dyn_cast(inst); - if(reg == RI->getReturnValue()) return true; - else return false; - }else return false; - } - - static RegisterFIRegSelector A("FuncArgRegSelector", new FuncArgRegSelector()); - static RegisterFIRegSelector B("RetValRegSelector", new RetValRegSelector()); - static RegisterFIRegSelector C("FuncDestRegSelector", new FuncDestRegSelector()); -} - - diff --git a/llvm_passes/software_failures/_SoftwareFaultRegSelectors.h b/llvm_passes/software_failures/_SoftwareFaultRegSelectors.h deleted file mode 100644 index dd662209..00000000 --- a/llvm_passes/software_failures/_SoftwareFaultRegSelectors.h +++ /dev/null @@ -1,40 +0,0 @@ -#include "llvm/IR/Value.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Constants.h" -#include "FIInstSelector.h" -#include "FIRegSelector.h" -#include "FICustomSelectorManager.h" - -#include "llvm/IR/IntrinsicInst.h" -#include -#include -#include - -using namespace std; -namespace llfi { - class FuncArgRegSelector: public SoftwareFIRegSelector { - public: - FuncArgRegSelector(int target_arg) : pos_argument(target_arg), specified_arg(true) {}; - FuncArgRegSelector():pos_argument(0), specified_arg(false) {}; - private: - int pos_argument; - bool specified_arg; - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst); - virtual bool isRegofInstFITarget(Value* reg, Instruction* inst, int pos); - }; - - class FuncDestRegSelector: public SoftwareFIRegSelector { - private: - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst); - - }; - - class RetValRegSelector: public SoftwareFIRegSelector { - private: - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst); - - }; - -} - - diff --git a/llvm_passes/software_failures/_Timing_HighFrequentEventSelector.cpp b/llvm_passes/software_failures/_Timing_HighFrequentEventSelector.cpp deleted file mode 100644 index 796a0483..00000000 --- a/llvm_passes/software_failures/_Timing_HighFrequentEventSelector.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "llvm/Pass.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Support/CFG.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" - -#include "FIInstSelector.h" -#include "FICustomSelectorManager.h" -#include "_SoftwareFaultRegSelectors.h" - -#include -#include -#include -#include - -using namespace llvm; -namespace llfi { - class _Timing_HighFrequentEventInstSelector : public SoftwareFIInstSelector { - public: - _Timing_HighFrequentEventInstSelector() { - if (funcNames.size() == 0) { - funcNames.insert(std::string("fread")); - funcNames.insert(std::string("fopen")); - funcNames.insert(std::string("fwrite")); - } - } - - virtual void getCompileTimeInfo(std::map& info){ - info["failure_class"] = "Timing"; - info["failure_mode"] = "HighFrequentEvent"; - for(std::set::iterator SI = funcNames.begin(); SI != funcNames.end(); SI++) { - info["targets"] += *SI + "()/"; - } - info["targets"] += "return"; - info["injector"] = "SleepInjector"; - } - - private: - static std::set funcNames; - - virtual bool isInstFITarget(Instruction* inst) { - if (isa(inst)) { - CallInst* CI = dyn_cast(inst); - Function* called_func = CI->getCalledFunction(); - if (called_func == NULL) { - return false; - } - std::string func_name = std::string(called_func->getName()); - if (funcNames.find(func_name) != funcNames.end()) { - return true; - } else { - return false; - } - } else { - return isa(inst); - } - } - }; - - std::set _Timing_HighFrequentEventInstSelector::funcNames; - - class _Timing_HighFrequentEventRegSelector : public SoftwareFIRegSelector { - private: - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst) { - if (isa(inst)) { - CallInst* CI = dyn_cast(inst); - Function* called_func = CI->getCalledFunction(); - if (called_func == NULL) { - return false; - } - return reg == CI; // selects dst register - } else if (isa(inst)) { - ReturnInst* RI = dyn_cast(inst); - return reg == RI->getReturnValue(); - } else { - return false; - } - } - }; - - static RegisterFIInstSelector A("HighFrequentEvent(Timing)", new _Timing_HighFrequentEventInstSelector()); - static RegisterFIRegSelector B("HighFrequentEvent(Timing)", new _Timing_HighFrequentEventRegSelector()); -} diff --git a/migration.md b/migration.md new file mode 100644 index 00000000..1f788037 --- /dev/null +++ b/migration.md @@ -0,0 +1,338 @@ +# LLTFI LLVM Upgrade Migration Guide + +This document describes the work required to upgrade LLTFI from LLVM 15 to a +current LLVM release (17–20). Tasks are split between those that require a +human and those that Claude Code can handle autonomously once the build +environment is ready. + +Migrated from: **LLVM 15.0** (`/usr/lib/llvm-15`) +Current LLVM version: **LLVM 20.1** (`/usr/lib/llvm-20`) + +--- + +## Status: further along than it appears + +The most difficult part of a major LLVM upgrade — migrating from the legacy +pass manager (PM) to the new PM — is **already done for all core LLTFI +passes**. `RegisterPasses.cpp` exposes `llvmGetPassPluginInfo()`, every core +pass has a `PassInfoMixin` wrapper, and `instrument.py` already uses +`-load-pass-plugin`. What remains is a set of API-level fixups and one +incomplete pass migration. + +--- + +## Human tasks (cannot be delegated to Claude Code) + +### H-1 — Install LLVM 17+ and update the build configuration ✅ DONE +**Estimated time: 2–4 hours** + +This is the only hard prerequisite. Claude Code cannot change what is installed +on the machine, so the compile-verify-fix loop that de-risks the code changes +cannot begin until this is done. + +Steps: +1. Install a current LLVM release via the LLVM apt repository: + ```bash + wget https://apt.llvm.org/llvm.sh && chmod +x llvm.sh + sudo ./llvm.sh 20 # or 17/18/19 + ``` +2. Update the build configuration to point at the new installation: + ```bash + # Delete the old build root first + rm -rf /home/karthik/Programs/LLTFI-build + + ./setup -LLFI_BUILD_ROOT /home/karthik/Programs/LLTFI-build \ + -LLVM_SRC_ROOT /home/karthik/Programs/llvm-project \ + -LLVM_DST_ROOT /usr/lib/llvm-20 \ + -LLVM_GXX_BIN_DIR /usr/lib/llvm-20/bin + ``` +3. Attempt an initial build (`cd LLTFI-build && make`) and record the errors. + The output of this first build attempt is the input for Claude Code's work. + +### H-2 — Review IRBuilder insertion-point correctness +**Estimated time: 2–3 hours** + +The instruction-construction API fixes (task C-1 below) replace +`new AllocaInst/StoreInst/LoadInst` calls with IRBuilder equivalents. These +compile cleanly whether correct or not, but an incorrect insertion point places +an instruction in the wrong basic block or wrong position, producing +miscompiled IR. A human who understands the intended pass semantics should +review these diffs specifically before merging. The key files to scrutinise +are `FaultInjectionPass.cpp` and `InstTracePass.cpp`. + +**Claude Code preliminary review (2026-04-14):** Both files were inspected and +the insertion points appear semantically correct: + +- `FaultInjectionPass.cpp` lines 229–230, 266: uses the `BasicBlock*` + insertAtEnd form (`new AllocaInst(type, 0, name, block)`), which inserts at + the end of the given block. Allocas go in the entry block (standard practice); + the store and load go in the respective exit block immediately after the + injected value is computed. This is the same logical placement as before the + API change. +- `InstTracePass.cpp` lines 141–142, 144, 152, 170: uses `BasicBlock::iterator` + from `getFirstNonPHIOrDbgOrLifetime()` for alloca insertion (correct: before + any non-PHI/dbg/lifetime instruction in the entry block), and + `insertPoint->getIterator()` for stores (correct: inserts immediately before + the trace call). All 21 tests pass with these changes. + +The remaining human task is to verify the *logical* placement makes sense for +the pass's intended semantics, not just that it compiles and passes tests. + +### H-3 — Provide onnx-mlir environment for real-model validation +**Estimated time: 1–2 hours** + +The only step that requires a human is making onnx-mlir available: + +1. Install onnx-mlir or set `ONNX_MLIR_BUILD` to point at an existing build. +2. Run `compile.sh` in `sample_programs/ml_sample_programs/vision_models/mnist/` + to produce `model.ll`. +3. Build `SIDHelperFunctions.ll` (needed for end-to-end numerical run): + ```bash + cd llvm_passes/instruction_duplication/shared_lib + sh compile_shrd_lib.sh + ``` + +Once `model.ll` and `SIDHelperFunctions.ll` exist, Claude Code (or the test +suite) takes over automatically. Two tests in `test_instruction_duplication.py` +cover the rest: + +- **`real_model_structural`** — applies the pass to the real onnx-mlir IR and + verifies that `compareFloatValues` calls are inserted and arithmetic + instructions are duplicated. This proves the pass handles genuine onnx-mlir + IR patterns, not just the synthetic fixtures used in the seven other tests. +- **`real_model_end_to_end`** — runs both the baseline `model.ll` and the + duplicated+inlined model through `lli`, then asserts their outputs are + identical. Because `compareFloatValues(x, x) == x` (bitwise AND of equal + floats is the float itself), the outputs must match when no fault is injected. + Any divergence indicates a pass transformation bug. + +Both tests SKIP gracefully when their prerequisites are absent, so the suite +continues to report 0 failures even before onnx-mlir is set up. Run them via: + +```bash +cd /path/to/LLTFI-build/test_suite +python3 SCRIPTS/test_instruction_duplication.py +``` + +### H-4 — Final test sign-off ✅ DONE + +Run the full test suite against the new LLVM version and confirm no regressions: +```bash +cd /path/to/LLTFI-build/test_suite +python3 SCRIPTS/llfi_test --all_cpp # expect 21/21 +python3 SCRIPTS/llfi_test --all_ml # expect all non-SKIP to pass +``` +**Result: 21/21 PASS.** All hardware fault, trace tool, makefile +generation tests pass against LLVM 20.1. The `--all_ml` tests that +require optional dependencies (onnx-mlir, TensorFlow, PyTorch) are reported as +SKIP (not FAIL) on machines where those are not installed. + +--- + +## Claude Code tasks (can be done autonomously once LLVM 17+ is installed) + +### C-1 — Replace deprecated instruction-construction API (9 sites) ✅ DONE +**Estimated time: 1 day** + +LLVM 16/17 removed the `InsertBefore` Instruction-pointer parameter from +`AllocaInst`, `StoreInst`, and `LoadInst` constructors. All 9 affected sites +must be replaced with IRBuilder equivalents. + +Affected files and sites: + +| File | Sites | Pattern | +|------|-------|---------| +| `llvm_passes/core/FaultInjectionPass.cpp` | 3 | `new AllocaInst(...)`, `new StoreInst(...)`, `new LoadInst(...)` | +| `llvm_passes/core/InstTracePass.cpp` | 6 | `new AllocaInst(...)` ×3, `new StoreInst(...)` ×3 | + +Example replacement: +```cpp +// Before (LLVM 15) +AllocaInst *tmploc = new AllocaInst(fitype, 0, "tmploc", entryblock); + +// After (LLVM 17+) +IRBuilder<> builder(&entryblock->front()); +AllocaInst *tmploc = builder.CreateAlloca(fitype, nullptr, "tmploc"); +``` + +Note: `CallInst::Create(func, args, "", insertPoint)` with a raw `Instruction *` +as the last argument may also need updating to use a `BasicBlock::iterator`; +check each call site after the build reveals errors. + +### C-2 — Fix iterator return-type changes ✅ DONE +**Estimated time: 2–3 hours** + +Two methods changed their return type from `Instruction *` to +`BasicBlock::iterator` in LLVM 17: + +- `getFirstNonPHIOrDbgOrLifetime()` — used in `InstTracePass.cpp:123` +- `getNextNonDebugInstruction()` — used in `Utils.cpp:125` and + `InstructionDuplication.cpp:440–443` + +At each call site, either dereference the iterator (`&*iter`) or update the +surrounding code to work with iterators directly. + +### C-3 — Migrate InstructionDuplication to the new pass manager ✅ DONE +**Estimated time: 2–3 days** + +`InstructionDuplication.cpp` is the only remaining pass still using the legacy +PM. It needs to be migrated to `PassInfoMixin` to match all other LLTFI passes. + +Changes required: + +1. **`InstructionDuplication.cpp`**: Change class base from `FunctionPass` to + `PassInfoMixin`. Rename `runOnFunction(Function + &F)` to `run(Function &F, FunctionAnalysisManager &AM)` returning + `PreservedAnalyses`. Remove `static char ID` and `RegisterPass<>`. + The five existing `PassInfoMixin` passes in the codebase serve as templates. + +2. **`RegisterPasses.cpp`**: Add a `registerPipelineParsingCallback` entry for + `InstructionDuplicationPass`, following the existing pattern for the other + passes. Decide whether `SEDPasses.so` should be merged into the main + `llfi-passes.so` or kept separate; keeping it separate is simpler. + +3. **`llvm_passes/instruction_duplication/CMakeLists.txt`**: Update to add the + `llvmGetPassPluginInfo` export if `SEDPasses.so` stays separate, or adjust + linking if merged. + +4. **`shared_lib/build.sh` and `compile_shrd_lib.sh`**: Update `opt` invocation + to remove `-load` / `--enable-new-pm=0` and use `-load-pass-plugin` with + `--passes=InstructionDuplicationPass`. + +5. **`README.md`** (in `instruction_duplication/`): Update example `opt` + command. + +6. **`test_suite/SCRIPTS/test_instruction_duplication.py`**: Update `_run_pass` + to use `-load-pass-plugin` and `--passes=InstructionDuplicationPass` instead + of `-load` / `--InstructionDuplicationPass` / `--enable-new-pm=0`. + +### C-4 — Fix `Module::getGlobalList().push_back()` in `Utils.cpp` ✅ DONE +**Estimated time: 30 minutes** + +`getGlobalList()` was removed in LLVM 17. The one call site in +`Utils.cpp:206` creates a `GlobalVariable` and appends it to the module. +Replace with the `GlobalVariable` constructor overload that takes a `Module *` +directly (which inserts the variable into the module automatically): + +```cpp +// Before (LLVM 15) +nameStr = new GlobalVariable(name_c->getType(), true, + GlobalVariable::InternalLinkage, name_c, gv_nameStr.c_str()); +M.getGlobalList().push_back(nameStr); + +// After (LLVM 17+) +nameStr = new GlobalVariable(M, name_c->getType(), true, + GlobalVariable::InternalLinkage, name_c, gv_nameStr.c_str()); +``` + +### C-6 — Iterative build-fix loop ✅ DONE +**Estimated time: 1 week (wall clock; most of this is Claude Code running builds)** + +After applying C-1 through C-5, run `make` and address any remaining compiler +errors introduced by LLVM 17–20 API changes not captured above. LLVM releases +between 17 and 20 introduce additional deprecations (e.g. changes to +`Value::use_iterator`, `DebugLoc` APIs, `MDNode` helpers) that may surface +depending on the exact target version. Claude Code can drive this loop: +read the error, identify the fix, apply it, rebuild. + +### C-7 — C++ static analysis and formatting cleanup ✅ DONE + +After the build was clean, `clang-format-20` and `clang-tidy-20` were run +across all hand-written C++ sources under `llvm_passes/`. In addition to style +issues, clang-tidy surfaced several real bugs: + +| Bug | File | Fix | +|-----|------|-----| +| Double-free in singleton destructor | `Controller.cpp` | Removed `delete ctrl` from `~Controller()` — object is not heap-allocated by the time the destructor runs | +| File stream leak | `LLFIDotGraphPass.cpp` | Added missing `fclose(outputFile)` | +| Null dereference via unchecked `fopen` | `GenLLFIIndexPass.cpp` | Moved `fclose` inside the `if (outputFile)` block | +| Uninitialized field `isChainDuplication` | `InstructionDuplicationPass` constructor | Added explicit `isChainDuplication = false` initializer | +| Null `getCalledFunction()` dereference | `ProfilingPass.cpp`, `InstructionDuplication.cpp`, `CustomTensorOperatorInstSelector.cpp` | Added null checks before name comparison | +| Unchecked null `dyn_cast` results | `Utils.cpp`, multiple | Changed to `cast<>` (asserting) where type is guaranteed by a prior opcode check; added null checks elsewhere | + +Style fixes applied across 26 files: `override` on all overriding methods, +`virtual ~Base() = default` on abstract base classes, `.empty()` replacing +`.size() == 0`, initialized-at-declaration for all local pointers, +`strncpy`/`strncat` replacing unbounded `strcpy`/`strcat`, `const T&` in +range-for loops, and `cl::opt::getValue()` to avoid slicing. + +Infrastructure added: +- `.clang-tidy` — project tidy config with intentionally disabled checks documented +- `lint.sh` — unified C++ and Python lint runner (`./lint.sh --fix` auto-formats) +- `CODING_GUIDELINES.md` — expanded with `override`, variable initialisation, container emptiness, and `cast<>` vs `dyn_cast<>` sections + +### C-8 — Secondary pass on ML/SID code ✅ DONE + +**ML fault injection and instruction duplication passes**: + +| File | Fix | +|------|-----| +| `ProfilingPass.cpp` | `dyn_cast` → `cast` after `isa<>` check in `insertCallForMLFIStats()` | +| `InstructionDuplication.cpp` | `for (auto insVector :` → `for (const auto& insVector :` to avoid copying inner vectors; removed dead `return false;` after exhaustive if/else | + +**Documentation fixes**: + +| File | Fix | +|------|-----| +| `caveats.txt` | LLVM version references updated 15 → 20; duplicate item number fixed | +| `llvm_passes/instruction_duplication/README.md` | Final `opt` invocation updated from legacy PM `-always-inline` to `--passes=always-inline` (legacy PM removed in LLVM 17) | +| `llvm_passes/instruction_duplication/shared_lib/build.sh` | Uses `LLVM_GXX_BIN_DIR` env var to find versioned `clang` (fixes failure on Ubuntu with apt-installed LLVM where only `clang-20` exists) | +| `llvm_passes/instruction_duplication/shared_lib/compile_shrd_lib.sh` | Same fix for `clang++` | + +--- + +## Recommended order of work + +``` +H-1 Install LLVM 17+ and attempt initial build ✅ DONE + └─> C-1 Fix instruction-construction API ✅ DONE + └─> C-2 Fix iterator return-type changes ✅ DONE + └─> C-4 Fix getGlobalList ✅ DONE + └─> C-5 Iterative build-fix loop ✅ DONE + └─> C-6 C++ static analysis and formatting cleanup ✅ DONE + └─> C-7 ML/SID secondary pass, doc fixes ✅ DONE +H-2 Review IRBuilder insertion-point diffs Pending + └─> C-3 Migrate InstructionDuplication pass ✅ DONE +H-3 Validate InstructionDuplication on onnx-mlir IR Pending +H-4 Final test sign-off ✅ DONE (21/21) +``` + +--- + +## Effort summary + +| Task | Owner | Status | Estimated time | +|------|-------|--------|---------------| +| H-1: Install LLVM 20 and update build config | Human | ✅ Done | 2–4 hours | +| H-2: Review IRBuilder insertion-point correctness | Human | Pending | 2–3 hours | +| H-3: Provide onnx-mlir environment (install + compile.sh) | Human | Pending | 1–2 hours | +| H-4: Final test sign-off | Human | ✅ Done (21/21) | — | +| **Total human time remaining** | | | **~3–5 hours** | +| C-1: Deprecated instruction-construction API | Claude Code | ✅ Done | — | +| C-2: Iterator return-type fixes | Claude Code | ✅ Done | — | +| C-3: InstructionDuplication new PM migration | Claude Code | ✅ Done | — | +| C-4: `getGlobalList` fix | Claude Code | ✅ Done | — | +| C-5: Iterative build-fix loop | Claude Code | ✅ Done | — | +| C-6: C++ static analysis and formatting cleanup | Claude Code | ✅ Done | — | +| C-7: ML/SID, doc fixes | Claude Code | ✅ Done | — | +| **Total Claude Code time remaining** | | | **None — all done** | + +Without Claude Code, a human developer would need approximately **2–3 weeks** +of active work. With Claude Code handling the mechanical fixes and the +build-fix loop, human involvement drops to roughly **1.5–2 days** of active +effort, with Claude Code running largely autonomously in between. + +--- + +## What could go wrong + +- **Subtle IRBuilder insertion-point bugs**: compile successfully but produce + miscompiled IR. Mitigated by H-2 (human review) and the existing test suite. +- **New PM semantics differences**: the new PM runs passes in a different order + and does not support inter-pass mutable state in the same way. The + `Controller` singleton used by the FI selectors should be checked for + thread-safety assumptions that the new PM may violate. +- **onnx-mlir compatibility**: onnx-mlir targets a specific LLVM version + internally. If the onnx-mlir build and the LLTFI build target different LLVM + versions, llvm-link may refuse to link the bitcode. This is an environment + concern, not a code concern, but it could block H-3. diff --git a/runtime_lib/CMakeLists.txt b/runtime_lib/CMakeLists.txt index d0aa2534..b5e03861 100644 --- a/runtime_lib/CMakeLists.txt +++ b/runtime_lib/CMakeLists.txt @@ -15,8 +15,6 @@ add_library(llfi-rt SHARED InstTraceLib.c ProfilingLib.cpp Utils.c - #_FIDLSoftwareFaultInjectors.cpp - #_SoftwareFaultInjector.cpp is included in this file ) # For ML backends. This static library is intended to be linked with the ML diff --git a/runtime_lib/CommonFaultInjectors.cpp b/runtime_lib/CommonFaultInjectors.cpp index 3f090911..49682231 100644 --- a/runtime_lib/CommonFaultInjectors.cpp +++ b/runtime_lib/CommonFaultInjectors.cpp @@ -1,30 +1,30 @@ #include "FaultInjector.h" #include "FaultInjectorManager.h" -class BitFlipFI: public HardwareFaultInjector { - public: +class BitFlipFI : public HardwareFaultInjector { +public: virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit, - char *buf) { + char *buf) { unsigned fi_bytepos = fi_bit / 8; unsigned fi_bitpos = fi_bit % 8; buf[fi_bytepos] ^= 0x1 << fi_bitpos; } }; -class StuckAt0FI: public HardwareFaultInjector { - public: +class StuckAt0FI : public HardwareFaultInjector { +public: virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit, - char *buf) { + char *buf) { unsigned fi_bytepos = fi_bit / 8; unsigned fi_bitpos = fi_bit % 8; buf[fi_bytepos] &= ~(0x1 << fi_bitpos); } }; -class StuckAt1FI: public HardwareFaultInjector { - public: +class StuckAt1FI : public HardwareFaultInjector { +public: virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit, - char *buf) { + char *buf) { unsigned fi_bytepos = fi_bit / 8; unsigned fi_bitpos = fi_bit % 8; buf[fi_bytepos] |= 0x1 << fi_bitpos; diff --git a/runtime_lib/FaultInjectionLib.c b/runtime_lib/FaultInjectionLib.c index 22c9f5c2..702b6faf 100755 --- a/runtime_lib/FaultInjectionLib.c +++ b/runtime_lib/FaultInjectionLib.c @@ -61,10 +61,14 @@ void injectFaultImpl(const char *fi_type, long llfi_index, unsigned size, */ void _initRandomSeed() { unsigned int seed; - FILE* urandom = fopen("/dev/urandom", "r"); - fread(&seed, sizeof(int), 1, urandom); - fclose(urandom); - srand(seed); + FILE* urandom = fopen("/dev/urandom", "r"); + if (urandom != NULL) { + fread(&seed, sizeof(int), 1, urandom); + fclose(urandom); + } else { + seed = (unsigned int)time(NULL); + } + srand(seed); } // get whether to make decision based on probability diff --git a/runtime_lib/FaultInjector.h b/runtime_lib/FaultInjector.h index a395fb8e..b87f41a9 100644 --- a/runtime_lib/FaultInjector.h +++ b/runtime_lib/FaultInjector.h @@ -5,25 +5,15 @@ class FaultInjector { // TODO: need to change the interface when we inject multiple bits faults - public: +public: virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit, - char *buf) = 0; - //virtual std::string getFaultInjectorType() = 0; - virtual std::string getFaultInjectorType(){ - return std::string("Unknown"); - } + char *buf) = 0; + // virtual std::string getFaultInjectorType() = 0; + virtual std::string getFaultInjectorType() { return std::string("Unknown"); } }; -class HardwareFaultInjector: public FaultInjector { - std::string getFaultInjectorType(){ - return std::string("HardwareFault"); - } -}; - -class SoftwareFaultInjector: public FaultInjector { - std::string getFaultInjectorType(){ - return std::string("SoftwareFault"); - } +class HardwareFaultInjector : public FaultInjector { + std::string getFaultInjectorType() { return std::string("HardwareFault"); } }; #endif diff --git a/runtime_lib/FaultInjectorManager.cpp b/runtime_lib/FaultInjectorManager.cpp index 92db4b5a..aa39ebd3 100644 --- a/runtime_lib/FaultInjectorManager.cpp +++ b/runtime_lib/FaultInjectorManager.cpp @@ -1,7 +1,9 @@ -#include +#include "FaultInjectorManager.h" + +#include #include + #include "FaultInjector.h" -#include "FaultInjectorManager.h" FaultInjectorManager *FaultInjectorManager::getFaultInjectorManager() { static FaultInjectorManager fi_manager; @@ -10,12 +12,12 @@ FaultInjectorManager *FaultInjectorManager::getFaultInjectorManager() { void FaultInjectorManager::addFaultInjector(const std::string &name, FaultInjector *fi) { - //debug(("enter add fault injector\n")); + // debug(("enter add fault injector\n")); if (type_injector.find(name) == type_injector.end()) { type_injector.insert( - std::pair(name, fi)); + std::pair(name, fi)); } else { - std::cerr << "ERROR: Duplicated fault injector: " << name << std::endl; + fprintf(stderr, "ERROR: Duplicated fault injector: %s\n", name.c_str()); exit(1); } } @@ -24,36 +26,39 @@ FaultInjector *FaultInjectorManager::getFaultInjector(const std::string &name) { if (type_injector.find(name) != type_injector.end()) { return type_injector[name]; } else { - std::cerr << "ERROR: unknown fault injector: " << name << std::endl; + fprintf(stderr, "ERROR: unknown fault injector: %s\n", name.c_str()); exit(1); } } -std::vector FaultInjectorManager::getAllInjectorNames(){ +std::vector FaultInjectorManager::getAllInjectorNames() { std::vector names; - for(std::map::iterator MI = type_injector.begin(); - MI != type_injector.end(); MI++){ - names.push_back(MI->first); + for (std::map::iterator MI = + type_injector.begin(); + MI != type_injector.end(); MI++) { + names.push_back(MI->first); } return names; } -std::vector FaultInjectorManager::getInjectorNamesForType(std::string type_str){ - std::vector names; +std::vector +FaultInjectorManager::getInjectorNamesForType(std::string type_str) { + std::vector names; // std::cout << "start of getInjectorNamesForType()\n"; - for(std::map::iterator MI = type_injector.begin(); - MI != type_injector.end(); MI++){ - // std::cout << "checking:" << MI->first << "pointer addr: " << (MI->second) << "\n"; - // std::cout << " type: " << MI->second->getFaultInjectorType() << "\n"; - if(type_str == MI->second->getFaultInjectorType()) - names.push_back(MI->first); + for (std::map::iterator MI = + type_injector.begin(); + MI != type_injector.end(); MI++) { + // std::cout << "checking:" << MI->first << "pointer addr: " << (MI->second) + // << "\n"; std::cout << " type: " << MI->second->getFaultInjectorType() << + // "\n"; + if (type_str == MI->second->getFaultInjectorType()) + names.push_back(MI->first); } // std::cout << "end of getInjectorNamesForType()\n"; return names; } - -extern "C" void injectFaultImpl(const char *fi_type, long llfi_index, +extern "C" void injectFaultImpl(const char *fi_type, long llfi_index, unsigned size, unsigned fi_bit, char *buf) { FaultInjectorManager *m = FaultInjectorManager::getFaultInjectorManager(); FaultInjector *fi = m->getFaultInjector(fi_type); diff --git a/runtime_lib/FaultInjectorManager.h b/runtime_lib/FaultInjectorManager.h index e09a73d7..76b63365 100644 --- a/runtime_lib/FaultInjectorManager.h +++ b/runtime_lib/FaultInjectorManager.h @@ -1,41 +1,39 @@ #ifndef FAULT_INJECTOR_MANAGER_H #define FAULT_INJECTOR_MANAGER_H +#include "Utils.h" + +#include #include #include +#include #include #include -#include -#include - -#include "Utils.h" - class FaultInjector; class FaultInjectorManager { - public: +public: FaultInjectorManager() {} - public: +public: static FaultInjectorManager *getFaultInjectorManager(); - void addFaultInjector(const std::string &name, - FaultInjector *fi); + void addFaultInjector(const std::string &name, FaultInjector *fi); FaultInjector *getFaultInjector(const std::string &name); std::vector getAllInjectorNames(); std::vector getInjectorNamesForType(std::string type_str); - private: - std::map type_injector; +private: + std::map type_injector; }; struct RegisterFaultInjector { RegisterFaultInjector(const std::string &name, FaultInjector *fi) { - //debug(( "init\n")); + // debug(( "init\n")); FaultInjectorManager *m = FaultInjectorManager::getFaultInjectorManager(); - //debug(("get manager\n")); + // debug(("get manager\n")); m->addFaultInjector(name, fi); - //debug(("finish\n")); + // debug(("finish\n")); } }; diff --git a/runtime_lib/InjectorScanner.cpp b/runtime_lib/InjectorScanner.cpp index d5ed703e..5e192c96 100644 --- a/runtime_lib/InjectorScanner.cpp +++ b/runtime_lib/InjectorScanner.cpp @@ -1,57 +1,47 @@ -#include -#include #include -#include +#include +#include #include +#include + #include "FaultInjector.h" #include "FaultInjectorManager.h" using namespace std; -int main(int argc, char* argv[]) { - string output_file_path(""); - ofstream output_file; - for(int i = 0; i < argc; i++){ - if(string(argv[i]) == string("-o")){ - output_file_path = string(argv[i+1]); - } - // cout << "argv[" << i << "] = " << argv[i] << endl; - } - - if(output_file_path.length() != 0){ - output_file.open(output_file_path.c_str()); - } +int main(int argc, char *argv[]) { + string output_file_path(""); + ofstream output_file; + for (int i = 0; i < argc; i++) { + if (string(argv[i]) == string("-o")) { + output_file_path = string(argv[i + 1]); + } + // cout << "argv[" << i << "] = " << argv[i] << endl; + } - FaultInjectorManager* faultinjectormanager = FaultInjectorManager::getFaultInjectorManager(); - vector hardwarefaultinjectornames = faultinjectormanager->getInjectorNamesForType(string("HardwareFault")); - if(output_file.is_open()){ - output_file << "HardwareFaultInjector:" << endl; - }else{ - cout << "HardwareFaultInjector:" << endl; - } - for(int i = 0; i softwarefaultinjectornames = faultinjectormanager->getInjectorNamesForType(string("SoftwareFault")); - if(output_file.is_open()){ - output_file << "SoftwareFaultInjector:" << endl; - }else{ - cout << "SoftwareFaultInjector:" << endl; - } - for(int i = 0; i hardwarefaultinjectornames = + faultinjectormanager->getInjectorNamesForType(string("HardwareFault")); + if (output_file.is_open()) { + output_file << "HardwareFaultInjector:" << endl; + } else { + cout << "HardwareFaultInjector:" << endl; + } + for (int i = 0; i < hardwarefaultinjectornames.size(); i++) { + if (output_file.is_open()) { + output_file << " - " << hardwarefaultinjectornames[i] << endl; + } else { + cout << " - " << hardwarefaultinjectornames[i] << endl; + } + } - if(output_file.is_open()) output_file.close(); + if (output_file.is_open()) + output_file.close(); - return 0; -} \ No newline at end of file + return 0; +} diff --git a/runtime_lib/Instruction.def b/runtime_lib/Instruction.def index de1ea51f..03f207fa 100644 --- a/runtime_lib/Instruction.def +++ b/runtime_lib/Instruction.def @@ -18,7 +18,7 @@ HANDLE_INST ( 7, Unreachable , UnreachableInst, 1) HANDLE_INST ( 8, CleanupRet , CleanupReturnInst, 1) HANDLE_INST ( 9, CatchRet , CatchReturnInst, 1) HANDLE_INST ( 10, CatchSwitch , CatchSwitchInst, 1) -HANDLE_INST ( 11, CallBr , CallBr, 1) +HANDLE_INST ( 11, CallBr , CallBrInst, 1) // Standard unary operators... HANDLE_INST( 12, FNeg , UnaryOperator, 1) diff --git a/runtime_lib/MLFaultInjectionLib.cpp b/runtime_lib/MLFaultInjectionLib.cpp index 5b4de3a5..11b5bb1a 100644 --- a/runtime_lib/MLFaultInjectionLib.cpp +++ b/runtime_lib/MLFaultInjectionLib.cpp @@ -1,11 +1,11 @@ -#include -#include -#include #include #include +#include +#include #include #include #include +#include #define llu long long unsigned #define OPTION_LENGTH 512 @@ -13,8 +13,8 @@ using namespace std; union inputBuffer { - uint32_t ui; - float f; + uint32_t ui; + float f; }; // DS to hold runtime FI configuration. @@ -44,7 +44,7 @@ struct LLTFIConfig { static LLTFIConfig LLTFI_config; static llu LLTFI_CurrentCycle = 0; static int LLTFI_FICycleIndex = 0; -FILE *injectedfaultsFile = NULL; +FILE *injectedfaultsFile = nullptr; bool LLTFI_doFI = true; // Function to parse the runtime configuration file and @@ -55,29 +55,29 @@ void parseLLTFIConfigFile() { const unsigned CONFIG_LINE_LENGTH = 1024; char line[CONFIG_LINE_LENGTH]; char option[OPTION_LENGTH]; - char *value = NULL; + char *value = nullptr; int fi_next_cycles_index = 0; // Open the runtime configuration file. strncpy(ficonfigfilename, "llfi.config.runtime.txt", 80); FILE *ficonfigFile; ficonfigFile = fopen(ficonfigfilename, "r"); - if (ficonfigFile == NULL) { + if (ficonfigFile == nullptr) { fprintf(stderr, "ERROR: Unable to open llfi config file %s\n", ficonfigfilename); exit(1); } // Iterate through all the options in the runtime config file. - while (fgets(line, CONFIG_LINE_LENGTH, ficonfigFile) != NULL) { + while (fgets(line, CONFIG_LINE_LENGTH, ficonfigFile) != nullptr) { if (line[0] == '#') continue; value = strtok(line, "="); strncpy(option, value, OPTION_LENGTH); - value = strtok(NULL, "="); + value = strtok(nullptr, "="); - //debug(("option, %s, value, %s;", option, value)); + // debug(("option, %s, value, %s;", option, value)); if (strcmp(option, "fi_type") == 0) { strncpy(LLTFI_config.fi_type, value, OPTION_LENGTH); @@ -89,23 +89,25 @@ void parseLLTFIConfigFile() { LLTFI_config.fi_cycle.push_back(atoll(value)); } - else if (strcmp(option, "fi_max_multiple") == 0){ + else if (strcmp(option, "fi_max_multiple") == 0) { LLTFI_config.fi_max_multiple = atoi(value); } - else if (strcmp(option, "fi_next_cycle") == 0){ + else if (strcmp(option, "fi_next_cycle") == 0) { LLTFI_config.fi_cycle.push_back(atoll(value)); } // Parse FI stats for ML applications else if (strcmp(option, "ml_layer_name") == 0) { strncpy(LLTFI_config.fi_ml_layer_name, value, 100); - if (LLTFI_config.fi_ml_layer_name[strlen(LLTFI_config.fi_ml_layer_name) - 1] == '\n') - LLTFI_config.fi_ml_layer_name[strlen(LLTFI_config.fi_ml_layer_name) - 1] = '\0'; + if (LLTFI_config.fi_ml_layer_name[strlen(LLTFI_config.fi_ml_layer_name) - + 1] == '\n') + LLTFI_config + .fi_ml_layer_name[strlen(LLTFI_config.fi_ml_layer_name) - 1] = '\0'; } else if (strcmp(option, "ml_layer_number") == 0) { - LLTFI_config.fi_ml_layer_num = atoll(value); + LLTFI_config.fi_ml_layer_num = atoll(value); } else { @@ -117,11 +119,13 @@ void parseLLTFIConfigFile() { } // Sanity checks - assert(LLTFI_config.fi_type != NULL && "No fault injector selected."); + assert(LLTFI_config.fi_type != nullptr && "No fault injector selected."); assert(LLTFI_config.fi_cycle.size() > 0 && "No fi_cycle selected"); - assert(LLTFI_config.fi_max_multiple > 0 && "invalid fi_max_multiple in config file"); - assert((LLTFI_config.fi_ml_layer_num > 0 || LLTFI_config.fi_ml_layer_num == -1) && - "ml_layer_number should be grater than 0"); + assert(LLTFI_config.fi_max_multiple > 0 && + "invalid fi_max_multiple in config file"); + assert((LLTFI_config.fi_ml_layer_num > 0 || + LLTFI_config.fi_ml_layer_num == -1) && + "ml_layer_number should be grater than 0"); // Sort the fi_cycle vector. sort(LLTFI_config.fi_cycle.begin(), LLTFI_config.fi_cycle.end()); @@ -131,99 +135,99 @@ void parseLLTFIConfigFile() { } extern "C" { - // This function will be called at the beginning of the main function. - void initInjections() { - - srand(time(0)); - parseLLTFIConfigFile(); - - char injectedfaultsfilename[80]; - strncpy(injectedfaultsfilename, "llfi.stat.fi.injectedfaults.txt", 80); - injectedfaultsFile = fopen(injectedfaultsfilename, "a"); - if (injectedfaultsFile == NULL) { - fprintf(stderr, "ERROR: Unable to open injected faults stat file %s\n", - injectedfaultsfilename); - exit(1); - } +// This function will be called at the beginning of the main function. +void initInjections() { + + srand(time(0)); + parseLLTFIConfigFile(); + + char injectedfaultsfilename[80]; + strncpy(injectedfaultsfilename, "llfi.stat.fi.injectedfaults.txt", 80); + injectedfaultsFile = fopen(injectedfaultsfilename, "a"); + if (injectedfaultsFile == nullptr) { + fprintf(stderr, "ERROR: Unable to open injected faults stat file %s\n", + injectedfaultsfilename); + exit(1); } +} - // This function will be called at the end of main() function. - void postInjections() { - fclose(injectedfaultsFile); - LLTFI_doFI = false; - } +// This function will be called at the end of main() function. +void postInjections() { + fclose(injectedfaultsFile); + LLTFI_doFI = false; +} - // Function to check if we should inject fault - bool preFunc(long llfi_index, unsigned opcode, unsigned my_reg_index, +// Function to check if we should inject fault +bool preFunc(long llfi_index, unsigned opcode, unsigned my_reg_index, unsigned total_reg_target_num) { - if (!LLTFI_doFI) return 0; - - LLTFI_CurrentCycle++; + if (!LLTFI_doFI) + return 0; - // If current cycle is the FI cycle. - if (LLTFI_CurrentCycle == LLTFI_config.fi_cycle[LLTFI_FICycleIndex]) { - LLTFI_FICycleIndex++; + LLTFI_CurrentCycle++; - if (LLTFI_FICycleIndex >= LLTFI_config.fi_max_multiple) - LLTFI_doFI = false; + // If current cycle is the FI cycle. + if (LLTFI_CurrentCycle == LLTFI_config.fi_cycle[LLTFI_FICycleIndex]) { + LLTFI_FICycleIndex++; - return true; - } + if (LLTFI_FICycleIndex >= LLTFI_config.fi_max_multiple) + LLTFI_doFI = false; - return 0; + return true; } - // Function to actually inject the fault. - void injectFunc(long llfi_index, unsigned size, char *buf, - unsigned my_reg_index, unsigned reg_pos, char* opcode_str) { + return 0; +} - fprintf(stderr, "MSG: injectFunc() has being called\n"); +// Function to actually inject the fault. +void injectFunc(long llfi_index, unsigned size, char *buf, + unsigned my_reg_index, unsigned reg_pos, char *opcode_str) { - unsigned int fi_bytepos, fi_bitpos; - unsigned char oldbuf; + fprintf(stderr, "MSG: injectFunc() has being called\n"); - fi_bitpos = rand() % size; - fi_bytepos = fi_bitpos / 8; - oldbuf = buf[fi_bytepos]; - inputBuffer oldVal = {.f = *((float*)buf)}; - inputBuffer newVal; + unsigned int fi_bytepos, fi_bitpos; + unsigned char oldbuf; - if (strcmp(LLTFI_config.fi_type, "bitflip") == 0) { + fi_bitpos = rand() % size; + fi_bytepos = fi_bitpos / 8; + oldbuf = buf[fi_bytepos]; + inputBuffer oldVal = {.f = *((float *)buf)}; + inputBuffer newVal; - int8_t val = buf[fi_bytepos]; - int shift = fi_bitpos % 8; + if (strcmp(LLTFI_config.fi_type, "bitflip") == 0) { - val ^= 0x1 << shift; + int8_t val = buf[fi_bytepos]; + int shift = fi_bitpos % 8; - buf[fi_bytepos] = val; - } - else { - assert(false && "Not recognized fi_type"); - } + val ^= 0x1 << shift; - newVal = {.f = *((float*)buf)}; - - if (LLTFI_config.fi_ml_layer_num > 0) - fprintf(injectedfaultsFile, - "FI stat: fi_type=%s, fi_max_multiple=%d, fi_index=%ld, " - "fi_cycle=%lld, fi_reg_index=%u, fi_reg_pos=%u, fi_reg_width=%u, " - "fi_bit=%u, opcode=%s, oldHex=0x%x, newHex=0x%x, oldFloat=%f, " - " newFloat=%f, ml_layer_name=%s, ml_layer_number=%d\n", - LLTFI_config.fi_type, LLTFI_config.fi_max_multiple, - llfi_index, LLTFI_CurrentCycle, my_reg_index, reg_pos, size, - fi_bitpos, opcode_str, oldVal.ui, newVal.ui, oldVal.f, newVal.f, - LLTFI_config.fi_ml_layer_name, LLTFI_config.fi_ml_layer_num); - else - fprintf(injectedfaultsFile, - "FI stat: fi_type=%s, fi_max_multiple=%d, fi_index=%ld, " - "fi_cycle=%lld, fi_reg_index=%u, fi_reg_pos=%u, fi_reg_width=%u, " - "fi_bit=%u, opcode=%s, oldHex=0x%x, newHex=0x%x, oldFloat=%f, " - " newFloat=%f\n", - LLTFI_config.fi_type, LLTFI_config.fi_max_multiple, - llfi_index, LLTFI_CurrentCycle, my_reg_index, reg_pos, size, - fi_bitpos, opcode_str, oldVal.ui, newVal.ui, oldVal.f, newVal.f); - - fflush(injectedfaultsFile); + buf[fi_bytepos] = val; + } else { + assert(false && "Not recognized fi_type"); } + + newVal = {.f = *((float *)buf)}; + + if (LLTFI_config.fi_ml_layer_num > 0) + fprintf(injectedfaultsFile, + "FI stat: fi_type=%s, fi_max_multiple=%d, fi_index=%ld, " + "fi_cycle=%lld, fi_reg_index=%u, fi_reg_pos=%u, fi_reg_width=%u, " + "fi_bit=%u, opcode=%s, oldHex=0x%x, newHex=0x%x, oldFloat=%f, " + " newFloat=%f, ml_layer_name=%s, ml_layer_number=%d\n", + LLTFI_config.fi_type, LLTFI_config.fi_max_multiple, llfi_index, + LLTFI_CurrentCycle, my_reg_index, reg_pos, size, fi_bitpos, + opcode_str, oldVal.ui, newVal.ui, oldVal.f, newVal.f, + LLTFI_config.fi_ml_layer_name, LLTFI_config.fi_ml_layer_num); + else + fprintf(injectedfaultsFile, + "FI stat: fi_type=%s, fi_max_multiple=%d, fi_index=%ld, " + "fi_cycle=%lld, fi_reg_index=%u, fi_reg_pos=%u, fi_reg_width=%u, " + "fi_bit=%u, opcode=%s, oldHex=0x%x, newHex=0x%x, oldFloat=%f, " + " newFloat=%f\n", + LLTFI_config.fi_type, LLTFI_config.fi_max_multiple, llfi_index, + LLTFI_CurrentCycle, my_reg_index, reg_pos, size, fi_bitpos, + opcode_str, oldVal.ui, newVal.ui, oldVal.f, newVal.f); + + fflush(injectedfaultsFile); +} } diff --git a/runtime_lib/ProfilingLib.cpp b/runtime_lib/ProfilingLib.cpp old mode 100755 new mode 100644 index 1eba5609..f3f8f293 --- a/runtime_lib/ProfilingLib.cpp +++ b/runtime_lib/ProfilingLib.cpp @@ -1,11 +1,10 @@ +#include +#include #include #include #include -#include - -#include -#include #include +#include struct layerProfCycle { int layerNo; @@ -20,7 +19,7 @@ struct layerProfCycle { this->cycleEnd = -1; } - void registerCycle (long long unsigned cycle) { + void registerCycle(long long unsigned cycle) { if (this->cycleStart == -1) { this->cycleStart = cycle; } @@ -31,8 +30,7 @@ struct layerProfCycle { static std::vector layerProfileInfo; static int64_t globalLayerNo = 0; -static layerProfCycle *currentLayer = NULL; - +static layerProfCycle *currentLayer = nullptr; // Export these functions in C dilect. extern "C" { @@ -43,20 +41,20 @@ static long long unsigned globalCycle = 0; void lltfiMLLayer(int64_t layerName, int64_t start) { - assert(start == 1 || start == 2 && "Layer start is denoted by 1 and end by 2"); + assert(start == 1 || + start == 2 && "Layer start is denoted by 1 and end by 2"); int64_t *layerNamePtr = &layerName; - char* layerNameStr = (char*)layerNamePtr; + char *layerNameStr = (char *)layerNamePtr; if (start == 1) { /* Layer started. */ globalLayerNo++; currentLayer = new layerProfCycle(globalLayerNo, std::string(layerNameStr)); - } - else { + } else { layerProfileInfo.push_back(*currentLayer); delete currentLayer; - currentLayer = NULL; + currentLayer = nullptr; } } @@ -65,7 +63,7 @@ void doProfiling(int opcode) { "dynamic instruction number too large to be handled by llfi"); opcodecount[opcode]++; globalCycle++; - if (currentLayer != NULL) + if (currentLayer != nullptr) currentLayer->registerCycle(globalCycle); } @@ -73,7 +71,7 @@ void endProfiling() { FILE *profileFile; char profilefilename[80] = "llfi.stat.prof.txt"; profileFile = fopen(profilefilename, "w"); - if (profileFile == NULL) { + if (profileFile == nullptr) { fprintf(stderr, "ERROR: Unable to open profiling result file %s\n", profilefilename); exit(1); @@ -86,10 +84,10 @@ void endProfiling() { long long unsigned total_cycle = 0; for (i = 0; i < 100; ++i) { assert(total_cycle >= 0 && - "total dynamic instruction cycle too large to be handled by llfi"); + "total dynamic instruction cycle too large to be handled by llfi"); if (opcodecount[i] > 0) { assert(opcode_cycle_arr[i] >= 0 && - "opcode does not exist, need to update instructions.def"); + "opcode does not exist, need to update instructions.def"); total_cycle += opcodecount[i] * opcode_cycle_arr[i]; } } @@ -104,6 +102,6 @@ void endProfiling() { layer.layerName.c_str(), layer.cycleStart, layer.cycleEnd); } - fclose(profileFile); + fclose(profileFile); } } // End of extern "C" diff --git a/runtime_lib/Utils.h b/runtime_lib/Utils.h index 1182cf75..b6d127f2 100644 --- a/runtime_lib/Utils.h +++ b/runtime_lib/Utils.h @@ -19,7 +19,9 @@ bool isLittleEndian(); #define DEBUG #ifdef DEBUG -#define debug(x) printf x; fflush(stdout); +#define debug(x) \ + printf x; \ + fflush(stdout); #else #define debug(x) #endif diff --git a/runtime_lib/_FIDLSoftwareFaultInjectors.cpp b/runtime_lib/_FIDLSoftwareFaultInjectors.cpp deleted file mode 100644 index 259ddb0e..00000000 --- a/runtime_lib/_FIDLSoftwareFaultInjectors.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// DO NOT MODIFY -#include "_SoftwareFaultInjectors.cpp" - -/********************* - * DEFAULT INJECTORS * - *********************/ - -/******************** - * CUSTOM INJECTORS * - ********************/ - diff --git a/runtime_lib/_SoftwareFaultInjectors.cpp b/runtime_lib/_SoftwareFaultInjectors.cpp deleted file mode 100644 index d4e36634..00000000 --- a/runtime_lib/_SoftwareFaultInjectors.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include "FaultInjector.h" -#include "FaultInjectorManager.h" -#include -#include -#include -#include -#include -#include -#include - -//2^20 == 32MB -#define MEM_EXHAUSTION_UNIT 33554432 - -class BitCorruptionInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - unsigned int fi_bytepos = fi_bit/8; - unsigned int fi_bitpos = fi_bit%8; - buf[fi_bytepos] ^= 0x1 << fi_bitpos; - return; - } - - static BitCorruptionInjector* getBitCorruptionInjector(){ - static BitCorruptionInjector* injector_ptr = NULL; - if(injector_ptr == NULL){ - injector_ptr = new BitCorruptionInjector(); - return injector_ptr; - }else return injector_ptr; - } -}; - -class MemoryLeakInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - void* fake_p = malloc(1024 * sizeof(char)); - void** newbuf = (void**) buf; - *newbuf = fake_p; - return; - } -}; - -class HangInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - while(1); - return; - } -}; - -class SleepInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - sleep(3); - return; - } -}; - -static RegisterFaultInjector DA("HighFrequentEvent(Timing)", new SleepInjector()); - -class ChangeValueInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - if(is_replace == false){ - int* newbuf = (int*) buf; - *newbuf = *newbuf + add_val; - } - else{ - int* newbuf = (int*) buf; - *newbuf = rep_val; - } - return; - } - - ChangeValueInjector(int val, bool replace):add_val(val), rep_val(val), is_replace(replace){}; - - private: - int add_val; - int rep_val; - bool is_replace; -}; - -class InappropriateCloseInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - if(add_close){ - FILE** newbuf = (FILE**) buf; - fclose(*newbuf); - }else{ - FILE* fp = fopen("fake_file.txt", "w"); - FILE** newbuf = (FILE**) buf; - *newbuf = fp; - } - return; - } - InappropriateCloseInjector(bool addclose):add_close(addclose){}; - - private: - bool add_close; -}; - -class StalePointerInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - void** newbuf = (void**) buf; - free(*newbuf); - } -}; - -class MemoryExhaustionInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - void* p = NULL; - void* left_space = NULL; - do{ - p = malloc(MEM_EXHAUSTION_UNIT); - if(p == NULL) p = malloc(MEM_EXHAUSTION_UNIT>>4); - if(p == NULL) p = malloc(MEM_EXHAUSTION_UNIT>>8); - if(p == NULL) p = malloc(MEM_EXHAUSTION_UNIT>>12); - if(p != NULL) left_space = p; - }while(p != NULL); - if(non_left_space){ - void** newbuf = (void**) buf; - *newbuf = p; - }else{ - void** newbuf = (void**) buf; - *newbuf = left_space; - } - return; - } - - MemoryExhaustionInjector(bool nonleftspace):non_left_space(nonleftspace) {}; - private: - bool non_left_space; -}; - -class WrongFormatInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - switch(*buf){ - case 1: - *buf = 2; break; - case 2: - *buf = 4; break; - case 4: - *buf = 8; break; - case 8: - *buf = 4; break; - case 10: - *buf = 4; break; - default: - break; - } - return; - } -}; - -class PthreadDeadLockInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - pthread_mutex_t mutex1 = PTHREAD_MUTEX_INITIALIZER; - pthread_mutex_lock(&mutex1); - pthread_t thread1 = pthread_t(*buf); - pthread_join(thread1, NULL); - pthread_mutex_lock(&mutex1); - return; - } -}; - -class PthreadThreadKillerInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf){ - pthread_t t = pthread_t(*buf); - sleep(0.02); - pthread_cancel(t); - return; - } -}; - -class PthreadRaceConditionInjector: public SoftwareFaultInjector { - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit,char *buf) { - pthread_mutex_t *fake_mutex = (pthread_mutex_t*) malloc(sizeof(pthread_mutex_t)); - pthread_mutex_init(fake_mutex, NULL); - pthread_mutex_t **newbuf = (pthread_mutex_t**) buf; - *newbuf = fake_mutex; - return; - } -}; - diff --git a/sample_programs/cpp_sample_programs/README.md b/sample_programs/cpp_sample_programs/README.md index 870fddff..83479c7d 100644 --- a/sample_programs/cpp_sample_programs/README.md +++ b/sample_programs/cpp_sample_programs/README.md @@ -6,4 +6,4 @@ export LLFI_BUILD_ROOT= 2. Call the `./compileAndRun.sh` script with appropriate arguments. -Example: For the `factorial` sample application, the first argument is 'factorial', and the second argument as the number to compute the factorial of (e.g., 6). +Example: For the `factorial` sample application, the argument is the number to compute the factorial of (e.g., 6). diff --git a/sample_programs/cpp_sample_programs/bfs/main.cpp b/sample_programs/cpp_sample_programs/bfs/main.cpp old mode 100755 new mode 100644 index e16a3441..4c17878f --- a/sample_programs/cpp_sample_programs/bfs/main.cpp +++ b/sample_programs/cpp_sample_programs/bfs/main.cpp @@ -1,212 +1,206 @@ -/*************************************************************************** - *cr - *cr (C) Copyright 2007 The Board of Trustees of the - *cr University of Illinois - *cr All Rights Reserved - *cr - ***************************************************************************/ -/* - Implementing Breadth first search on CUDA using algorithm given in DAC'10 - paper "An Effective GPU Implementation of Breadth-First Search" - - Copyright (c) 2010 University of Illinois at Urbana-Champaign. - All rights reserved. - - Permission to use, copy, modify and distribute this software and its documentation for - educational purpose is hereby granted without fee, provided that the above copyright - notice and this permission notice appear in all copies of this software and that you do - not sell the software. - - THE SOFTWARE IS PROVIDED "AS IS" AND WITHOUT WARRANTY OF ANY KIND,EXPRESS, IMPLIED OR - OTHERWISE. - - Author: Lijiuan Luo (lluo3@uiuc.edu) -*/ -#include -#include -#include -#include -#include "parboil.h" -#include -#include - -#define MAX_THREADS_PER_BLOCK 512 -#define NUM_SM 30//the number of Streaming Multiprocessors; may change in the future archs -#define NUM_SP 16//8//the number of Streaming processors within each SM; may change in the future - //architectures -#define EXP 4//3// EXP = log(NUM_SP), assuming NUM_SP is still power of 2 in the future architecture - //using EXP and shifting can speed up division operation -#define MOD_OP 8//7 // This variable is also related with NUM_SP; may change in the future architecture; - //using MOD_OP and "bitwise and" can speed up mod operation -#define INF 2147483647//2^31-1 - -#define UP_LIMIT 16677216//2^24 -#define WHITE 16677217 -#define GRAY 16677218 -#define GRAY0 16677219 -#define GRAY1 16677220 -#define BLACK 16677221 -int no_of_nodes; //the number of nodes in the graph -int edge_list_size;//the number of edges in the graph -FILE *fp; - -//typedef int2 Node; -//typedef int2 Edge; - -struct Node{ - int x; - int y; -}; - -struct Edge{ - int x; - int y; -}; -//Somehow "cudaMemset" does not work. So I use cudaMemcpy of constant variables for initialization -const int h_top = 1; -const int zero = 0; - -void runCPU(int argc, char** argv); -void runGPU(int argc, char** argv); -//////////////////////////////////////////////////////////////////// -//the cpu version of bfs for speed comparison -//the text book version ("Introduction to Algorithms") -//////////////////////////////////////////////////////////////////// -void BFS_CPU( Node * h_graph_nodes,Edge * h_graph_edges, - int * color, int * h_cost, int source){ - std::deque wavefront; - wavefront.push_back(source); - color[source] = GRAY; - int index; - while(!wavefront.empty()){ - index = wavefront.front(); - wavefront.pop_front(); - for(int i=h_graph_nodes[index].x; - i<(h_graph_nodes[index].y + - h_graph_nodes[index].x); i++) - { - int id = h_graph_edges[i].x; - if(color[id] == WHITE){ - h_cost[id]=h_cost[index]+1; - wavefront.push_back(id); - color[id] = GRAY; - } - } - color[index] = BLACK; - - - } - -} -//////////////////////////////////////////////////////////////////////////////// -// Main Program -//////////////////////////////////////////////////////////////////////////////// -int main( int argc, char** argv) -{ - no_of_nodes=0; - edge_list_size=0; - runCPU(argc,argv); -// if( cutCheckCmdLineFlag(argc, (const char**)argv, "device") ) -// cutilDeviceInit(argc, argv); -// else - //cudaSetDevice( cutGetMaxGflopsDeviceId() ); -// cudaSetDevice( 1); - - - //CUT_EXIT(argc, argv); -} -/////////////////////////////// -//FUNCTION: only run CPU version -//////////////////////////////////////////// -void runCPU( int argc, char** argv) -{ - - struct pb_Parameters *params; - struct pb_TimerSet timers; - - pb_InitializeTimerSet(&timers); - params = pb_ReadParameters(&argc, argv); - if ((params->inpFiles[0] == NULL) || (params->inpFiles[1] != NULL)) - { - fprintf(stderr, "Expecting one input filename\n"); - exit(-1); - } - - pb_SwitchToTimer(&timers, pb_TimerID_IO); - //printf("Reading File\n"); - //Read in Graph from a file - fp = fopen(params->inpFiles[0],"r"); - if(!fp) - { - printf("Error Reading graph file\n"); - return; - } - - int source; - - fscanf(fp,"%d",&no_of_nodes); - // allocate host memory - Node* h_graph_nodes = (Node*) malloc(sizeof(Node)*no_of_nodes); - int *color = (int*) malloc(sizeof(int)*no_of_nodes); - int start, edgeno; - // initalize the memory - for( unsigned int i = 0; i < no_of_nodes; i++) - { - fscanf(fp,"%d %d",&start,&edgeno); - h_graph_nodes[i].x = start; - h_graph_nodes[i].y = edgeno; - color[i]=WHITE; - } - //read the source node from the file - fscanf(fp,"%d",&source); - fscanf(fp,"%d",&edge_list_size); - int id,cost; - Edge* h_graph_edges = (Edge*) malloc(sizeof(Edge)*edge_list_size); - for(int i=0; i < edge_list_size ; i++) - { - fscanf(fp,"%d",&id); - fscanf(fp,"%d",&cost); - h_graph_edges[i].x = id; - h_graph_edges[i].y = cost; - } - if(fp) - fclose(fp); - - //printf("Read File\n"); - - // allocate mem for the result on host side - int* h_cost = (int*) malloc( sizeof(int)*no_of_nodes); - for(int i = 0; i < no_of_nodes; i++){ - h_cost[i] = INF; - } - h_cost[source] = 0; - //printf("start cpu version\n"); - unsigned int cpu_timer = 0; - pb_SwitchToTimer(&timers, pb_TimerID_COMPUTE); - BFS_CPU( h_graph_nodes, h_graph_edges, color, h_cost, source - ); - pb_SwitchToTimer(&timers, pb_TimerID_IO); - if(params->outFile!=NULL) - { - //printf("Result stored in %s\n", params->outFile); - FILE *fp = fopen(params->outFile,"w"); - fprintf(fp,"%d\n", no_of_nodes); - for(int i=0;i +#include +#include +#include +#include +#include + +#include "parboil.h" + +#define MAX_THREADS_PER_BLOCK 512 +#define NUM_SM \ + 30 // the number of Streaming Multiprocessors; may change in the future archs +#define NUM_SP \ + 16 // 8//the number of Streaming processors within each SM; may change in the + // future +// architectures +#define EXP \ + 4 // 3// EXP = log(NUM_SP), assuming NUM_SP is still power of 2 in the future + // architecture +// using EXP and shifting can speed up division operation +#define MOD_OP \ + 8 // 7 // This variable is also related with NUM_SP; may change in the future + // architecture; +// using MOD_OP and "bitwise and" can speed up mod operation +#define INF 2147483647 // 2^31-1 + +#define UP_LIMIT 16677216 // 2^24 +#define WHITE 16677217 +#define GRAY 16677218 +#define GRAY0 16677219 +#define GRAY1 16677220 +#define BLACK 16677221 +int no_of_nodes; // the number of nodes in the graph +int edge_list_size; // the number of edges in the graph +FILE *fp; + +// typedef int2 Node; +// typedef int2 Edge; + +struct Node { + int x; + int y; +}; + +struct Edge { + int x; + int y; +}; +// Somehow "cudaMemset" does not work. So I use cudaMemcpy of constant variables +// for initialization +const int h_top = 1; +const int zero = 0; + +void runCPU(int argc, char **argv); +void runGPU(int argc, char **argv); +//////////////////////////////////////////////////////////////////// +// the cpu version of bfs for speed comparison +// the text book version ("Introduction to Algorithms") +//////////////////////////////////////////////////////////////////// +void BFS_CPU(Node *h_graph_nodes, Edge *h_graph_edges, int *color, int *h_cost, + int source) { + std::deque wavefront; + wavefront.push_back(source); + color[source] = GRAY; + int index; + while (!wavefront.empty()) { + index = wavefront.front(); + wavefront.pop_front(); + for (int i = h_graph_nodes[index].x; + i < (h_graph_nodes[index].y + h_graph_nodes[index].x); i++) { + int id = h_graph_edges[i].x; + if (color[id] == WHITE) { + h_cost[id] = h_cost[index] + 1; + wavefront.push_back(id); + color[id] = GRAY; + } + } + color[index] = BLACK; + } +} +//////////////////////////////////////////////////////////////////////////////// +// Main Program +//////////////////////////////////////////////////////////////////////////////// +int main(int argc, char **argv) { + no_of_nodes = 0; + edge_list_size = 0; + runCPU(argc, argv); + // if( cutCheckCmdLineFlag(argc, (const char**)argv, "device") ) + // cutilDeviceInit(argc, argv); + // else + // cudaSetDevice( cutGetMaxGflopsDeviceId() ); + // cudaSetDevice( 1); + + // CUT_EXIT(argc, argv); +} +/////////////////////////////// +// FUNCTION: only run CPU version +//////////////////////////////////////////// +void runCPU(int argc, char **argv) { + + struct pb_Parameters *params; + struct pb_TimerSet timers; + + pb_InitializeTimerSet(&timers); + params = pb_ReadParameters(&argc, argv); + if ((params->inpFiles[0] == NULL) || (params->inpFiles[1] != NULL)) { + fprintf(stderr, "Expecting one input filename\n"); + exit(-1); + } + + pb_SwitchToTimer(&timers, pb_TimerID_IO); + // printf("Reading File\n"); + // Read in Graph from a file + fp = fopen(params->inpFiles[0], "r"); + if (!fp) { + printf("Error Reading graph file\n"); + return; + } + + int source; + + fscanf(fp, "%d", &no_of_nodes); + // allocate host memory + Node *h_graph_nodes = (Node *)malloc(sizeof(Node) * no_of_nodes); + int *color = (int *)malloc(sizeof(int) * no_of_nodes); + int start, edgeno; + // initalize the memory + for (unsigned int i = 0; i < no_of_nodes; i++) { + fscanf(fp, "%d %d", &start, &edgeno); + h_graph_nodes[i].x = start; + h_graph_nodes[i].y = edgeno; + color[i] = WHITE; + } + // read the source node from the file + fscanf(fp, "%d", &source); + fscanf(fp, "%d", &edge_list_size); + int id, cost; + Edge *h_graph_edges = (Edge *)malloc(sizeof(Edge) * edge_list_size); + for (int i = 0; i < edge_list_size; i++) { + fscanf(fp, "%d", &id); + fscanf(fp, "%d", &cost); + h_graph_edges[i].x = id; + h_graph_edges[i].y = cost; + } + if (fp) + fclose(fp); + + // printf("Read File\n"); + + // allocate mem for the result on host side + int *h_cost = (int *)malloc(sizeof(int) * no_of_nodes); + for (int i = 0; i < no_of_nodes; i++) { + h_cost[i] = INF; + } + h_cost[source] = 0; + // printf("start cpu version\n"); + unsigned int cpu_timer = 0; + pb_SwitchToTimer(&timers, pb_TimerID_COMPUTE); + BFS_CPU(h_graph_nodes, h_graph_edges, color, h_cost, source); + pb_SwitchToTimer(&timers, pb_TimerID_IO); + if (params->outFile != NULL) { + // printf("Result stored in %s\n", params->outFile); + FILE *fp = fopen(params->outFile, "w"); + fprintf(fp, "%d\n", no_of_nodes); + for (int i = 0; i < no_of_nodes; i++) + fprintf(fp, "%d %d\n", i, h_cost[i]); + fclose(fp); + } + + pb_SwitchToTimer(&timers, pb_TimerID_COMPUTE); + // cleanup memory + free(h_graph_nodes); + free(h_graph_edges); + free(color); + free(h_cost); + pb_SwitchToTimer(&timers, pb_TimerID_NONE); + pb_PrintTimerSet(&timers); + pb_FreeParameters(params); +} +/////////////////////////////// +// FUNCTION:only run GPU version +//////////////////////////////////////////// diff --git a/sample_programs/cpp_sample_programs/bfs/parboil.cpp b/sample_programs/cpp_sample_programs/bfs/parboil.cpp old mode 100755 new mode 100644 index 0884419a..5946852a --- a/sample_programs/cpp_sample_programs/bfs/parboil.cpp +++ b/sample_programs/cpp_sample_programs/bfs/parboil.cpp @@ -3,38 +3,39 @@ */ #include "parboil.h" + +#include #include #include -#include #if _POSIX_VERSION >= 200112L -# include +#include #endif /* Free an array of owned strings. */ -static void -free_string_array(char **string_array) -{ +static void free_string_array(char **string_array) { char **p; - if (!string_array) return; - for (p = string_array; *p; p++) free(*p); + if (!string_array) + return; + for (p = string_array; *p; p++) + free(*p); free(string_array); } /* Parse a comma-delimited list of strings into an * array of strings. */ -static char ** -read_string_array(char *in) -{ +static char **read_string_array(char *in) { char **ret; int i; - int count; /* Number of items in the input */ - char *substring; /* Current substring within 'in' */ + int count; /* Number of items in the input */ + char *substring; /* Current substring within 'in' */ /* Count the number of items in the string */ count = 1; - for (i = 0; in[i]; i++) if (in[i] == ',') count++; + for (i = 0; in[i]; i++) + if (in[i] == ',') + count++; /* Allocate storage */ ret = (char **)malloc((count + 1) * sizeof(char *)); @@ -47,8 +48,8 @@ read_string_array(char *in) /* Find length of substring */ for (substring_end = substring; - (*substring_end != ',') && (*substring_end != 0); - substring_end++); + (*substring_end != ',') && (*substring_end != 0); substring_end++) + ; substring_length = substring_end - substring; @@ -60,41 +61,35 @@ read_string_array(char *in) /* go to next substring */ substring = substring_end + 1; } - ret[i] = NULL; /* Write the sentinel value */ + ret[i] = NULL; /* Write the sentinel value */ return ret; } struct argparse { - int argc; /* Number of arguments. Mutable. */ - char **argv; /* Argument values. Immutable. */ + int argc; /* Number of arguments. Mutable. */ + char **argv; /* Argument values. Immutable. */ - int argn; /* Current argument number. */ - char **argv_get; /* Argument value being read. */ - char **argv_put; /* Argument value being written. - * argv_put <= argv_get. */ + int argn; /* Current argument number. */ + char **argv_get; /* Argument value being read. */ + char **argv_put; /* Argument value being written. + * argv_put <= argv_get. */ }; -static void -initialize_argparse(struct argparse *ap, int argc, char **argv) -{ +static void initialize_argparse(struct argparse *ap, int argc, char **argv) { ap->argc = argc; ap->argn = 0; ap->argv_get = ap->argv_put = ap->argv = argv; } -static void -finalize_argparse(struct argparse *ap) -{ +static void finalize_argparse(struct argparse *ap) { /* Move the remaining arguments */ - for(; ap->argn < ap->argc; ap->argn++) + for (; ap->argn < ap->argc; ap->argn++) *ap->argv_put++ = *ap->argv_get++; } /* Delete the current argument. */ -static void -delete_argument(struct argparse *ap) -{ +static void delete_argument(struct argparse *ap) { if (ap->argn >= ap->argc) { fprintf(stderr, "delete_argument\n"); } @@ -104,9 +99,7 @@ delete_argument(struct argparse *ap) /* Go to the next argument. Also, move the current argument to its * final location in argv. */ -static void -next_argument(struct argparse *ap) -{ +static void next_argument(struct argparse *ap) { if (ap->argn >= ap->argc) { fprintf(stderr, "next_argument\n"); } @@ -115,33 +108,25 @@ next_argument(struct argparse *ap) ap->argn++; } -static int -is_end_of_arguments(struct argparse *ap) -{ +static int is_end_of_arguments(struct argparse *ap) { return ap->argn == ap->argc; } -static char * -get_argument(struct argparse *ap) -{ +static char *get_argument(struct argparse *ap) { return *ap->argv_get; } -static char * -consume_argument(struct argparse *ap) -{ +static char *consume_argument(struct argparse *ap) { char *ret = get_argument(ap); delete_argument(ap); return ret; } -struct pb_Parameters * -pb_ReadParameters(int *_argc, char **argv) -{ +struct pb_Parameters *pb_ReadParameters(int *_argc, char **argv) { char *err_message; struct argparse ap; struct pb_Parameters *ret = - (struct pb_Parameters *)malloc(sizeof(struct pb_Parameters)); + (struct pb_Parameters *)malloc(sizeof(struct pb_Parameters)); /* Initialize the parameters structure */ ret->outFile = NULL; @@ -150,59 +135,54 @@ pb_ReadParameters(int *_argc, char **argv) /* Each argument */ initialize_argparse(&ap, *_argc, argv); - while(!is_end_of_arguments(&ap)) { + while (!is_end_of_arguments(&ap)) { char *arg = get_argument(&ap); /* Single-character flag */ if ((arg[0] == '-') && (arg[1] != 0) && (arg[2] == 0)) { - delete_argument(&ap); /* This argument is consumed here */ - - switch(arg[1]) { - case 'o': /* Output file name */ - if (is_end_of_arguments(&ap)) - { - err_message = "Expecting file name after '-o'\n"; - goto error; - } - free(ret->outFile); - ret->outFile = strdup(consume_argument(&ap)); - break; - case 'i': /* Input file name */ - if (is_end_of_arguments(&ap)) - { - err_message = "Expecting file name after '-i'\n"; - goto error; - } - ret->inpFiles = read_string_array(consume_argument(&ap)); - break; - case '-': /* End of options */ - goto end_of_options; + delete_argument(&ap); /* This argument is consumed here */ + + switch (arg[1]) { + case 'o': /* Output file name */ + if (is_end_of_arguments(&ap)) { + err_message = "Expecting file name after '-o'\n"; + goto error; + } + free(ret->outFile); + ret->outFile = strdup(consume_argument(&ap)); + break; + case 'i': /* Input file name */ + if (is_end_of_arguments(&ap)) { + err_message = "Expecting file name after '-i'\n"; + goto error; + } + ret->inpFiles = read_string_array(consume_argument(&ap)); + break; + case '-': /* End of options */ + goto end_of_options; default: - err_message = "Unexpected command-line parameter\n"; - goto error; + err_message = "Unexpected command-line parameter\n"; + goto error; } - } - else { + } else { /* Other parameters are ignored */ next_argument(&ap); } } /* end for each argument */ - end_of_options: - *_argc = ap.argc; /* Save the modified argc value */ +end_of_options: + *_argc = ap.argc; /* Save the modified argc value */ finalize_argparse(&ap); return ret; - error: +error: fputs(err_message, stderr); pb_FreeParameters(ret); return NULL; } -void -pb_FreeParameters(struct pb_Parameters *p) -{ +void pb_FreeParameters(struct pb_Parameters *p) { char **cpp; free(p->outFile); @@ -210,56 +190,47 @@ pb_FreeParameters(struct pb_Parameters *p) free(p); } -int -pb_Parameters_CountInputs(struct pb_Parameters *p) -{ +int pb_Parameters_CountInputs(struct pb_Parameters *p) { int n; - for (n = 0; p->inpFiles[n]; n++); + for (n = 0; p->inpFiles[n]; n++) + ; return n; } /*****************************************************************************/ /* Timer routines */ -static void -accumulate_time(pb_Timestamp *accum, - pb_Timestamp start, - pb_Timestamp end) -{ +static void accumulate_time(pb_Timestamp *accum, pb_Timestamp start, + pb_Timestamp end) { #if _POSIX_VERSION >= 200112L *accum += end - start; #else -# error "Timestamps not implemented for this system" +#error "Timestamps not implemented for this system" #endif } #if _POSIX_VERSION >= 200112L -static pb_Timestamp get_time() -{ +static pb_Timestamp get_time() { struct timeval tv; gettimeofday(&tv, NULL); - return (pb_Timestamp) (tv.tv_sec * 1000000LL + tv.tv_usec); + return (pb_Timestamp)(tv.tv_sec * 1000000LL + tv.tv_usec); } #else -# error "no supported time libraries are available on this platform" +#error "no supported time libraries are available on this platform" #endif -void -pb_ResetTimer(struct pb_Timer *timer) -{ +void pb_ResetTimer(struct pb_Timer *timer) { timer->state = pb_Timer_STOPPED; #if _POSIX_VERSION >= 200112L timer->elapsed = 0; #else -# error "pb_ResetTimer: not implemented for this system" +#error "pb_ResetTimer: not implemented for this system" #endif } -void -pb_StartTimer(struct pb_Timer *timer) -{ +void pb_StartTimer(struct pb_Timer *timer) { if (timer->state != pb_Timer_STOPPED) { fputs("Ignoring attempt to start a running timer\n", stderr); return; @@ -274,13 +245,12 @@ pb_StartTimer(struct pb_Timer *timer) timer->init = tv.tv_sec * 1000000LL + tv.tv_usec; } #else -# error "pb_StartTimer: not implemented for this system" +#error "pb_StartTimer: not implemented for this system" #endif } -void -pb_StartTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) -{ +void pb_StartTimerAndSubTimer(struct pb_Timer *timer, + struct pb_Timer *subtimer) { unsigned int numNotStopped = 0x3; // 11 if (timer->state != pb_Timer_STOPPED) { fputs("Warning: Timer was not stopped\n", stderr); @@ -302,24 +272,21 @@ pb_StartTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) { struct timeval tv; gettimeofday(&tv, NULL); - + if (numNotStopped & 0x2) { timer->init = tv.tv_sec * 1000000LL + tv.tv_usec; } - + if (numNotStopped & 0x1) { subtimer->init = tv.tv_sec * 1000000LL + tv.tv_usec; } } #else -# error "pb_StartTimer: not implemented for this system" +#error "pb_StartTimer: not implemented for this system" #endif - } -void -pb_StopTimer(struct pb_Timer *timer) -{ +void pb_StopTimer(struct pb_Timer *timer) { pb_Timestamp fini; @@ -337,15 +304,15 @@ pb_StopTimer(struct pb_Timer *timer) fini = tv.tv_sec * 1000000LL + tv.tv_usec; } #else -# error "pb_StopTimer: not implemented for this system" +#error "pb_StopTimer: not implemented for this system" #endif accumulate_time(&timer->elapsed, timer->init, fini); timer->init = fini; - } -void pb_StopTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) { +void pb_StopTimerAndSubTimer(struct pb_Timer *timer, + struct pb_Timer *subtimer) { pb_Timestamp fini; @@ -363,7 +330,6 @@ void pb_StopTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) return; } - timer->state = pb_Timer_STOPPED; subtimer->state = pb_Timer_STOPPED; @@ -374,25 +340,22 @@ void pb_StopTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) fini = tv.tv_sec * 1000000LL + tv.tv_usec; } #else -# error "pb_StopTimer: not implemented for this system" +#error "pb_StopTimer: not implemented for this system" #endif if (numNotRunning & 0x2) { accumulate_time(&timer->elapsed, timer->init, fini); timer->init = fini; } - + if (numNotRunning & 0x1) { accumulate_time(&subtimer->elapsed, subtimer->init, fini); subtimer->init = fini; } - } /* Get the elapsed time in seconds. */ -double -pb_GetElapsedTime(struct pb_Timer *timer) -{ +double pb_GetElapsedTime(struct pb_Timer *timer) { double ret; if (timer->state != pb_Timer_STOPPED) { @@ -402,22 +365,19 @@ pb_GetElapsedTime(struct pb_Timer *timer) #if _POSIX_VERSION >= 200112L ret = timer->elapsed / 1e6; #else -# error "pb_GetElapsedTime: not implemented for this system" +#error "pb_GetElapsedTime: not implemented for this system" #endif return ret; } -void -pb_InitializeTimerSet(struct pb_TimerSet *timers) -{ +void pb_InitializeTimerSet(struct pb_TimerSet *timers) { int n; - + timers->wall_begin = get_time(); timers->current = pb_TimerID_NONE; timers->async_markers = NULL; - for (n = 0; n < pb_TimerID_LAST; n++) { pb_ResetTimer(&timers->timers[n]); @@ -425,24 +385,24 @@ pb_InitializeTimerSet(struct pb_TimerSet *timers) } } -void -pb_AddSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID pb_Category) { - - struct pb_SubTimer *subtimer = (struct pb_SubTimer *) malloc - (sizeof(struct pb_SubTimer)); - +void pb_AddSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID pb_Category) { + + struct pb_SubTimer *subtimer = + (struct pb_SubTimer *)malloc(sizeof(struct pb_SubTimer)); + int len = strlen(label); - - subtimer->label = (char *) malloc (sizeof(char)*(len+1)); + + subtimer->label = (char *)malloc(sizeof(char) * (len + 1)); sprintf(subtimer->label, "%s\0", label); - + pb_ResetTimer(&subtimer->timer); subtimer->next = NULL; - + struct pb_SubTimerList *subtimerlist = timers->sub_timer_list[pb_Category]; if (subtimerlist == NULL) { - subtimerlist = (struct pb_SubTimerList *) malloc - (sizeof(struct pb_SubTimerList)); + subtimerlist = + (struct pb_SubTimerList *)malloc(sizeof(struct pb_SubTimerList)); subtimerlist->subtimer_list = subtimer; timers->sub_timer_list[pb_Category] = subtimerlist; } else { @@ -453,28 +413,30 @@ pb_AddSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID pb_Categ } element->next = subtimer; } - } -void -pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID category) -{ +void pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID category) { + + // switchToSub( NULL, NONE + // switchToSub( NULL, some + // switchToSub( some, some + // switchToSub( some, NONE -- tries to find "some" in NONE's sublist, which + // won't be printed -// switchToSub( NULL, NONE -// switchToSub( NULL, some -// switchToSub( some, some -// switchToSub( some, NONE -- tries to find "some" in NONE's sublist, which won't be printed - struct pb_Timer *topLevelToStop = NULL; if (timers->current != category && timers->current != pb_TimerID_NONE) { - // Switching to subtimer in a different category needs to stop the top-level current, different categoried timer. - // NONE shouldn't have a timer associated with it, so exclude from branch + // Switching to subtimer in a different category needs to stop the top-level + // current, different categoried timer. NONE shouldn't have a timer + // associated with it, so exclude from branch topLevelToStop = &timers->timers[timers->current]; - } + } + + struct pb_SubTimerList *subtimerlist = + timers->sub_timer_list[timers->current]; + struct pb_SubTimer *curr = + (subtimerlist == NULL) ? NULL : subtimerlist->current; - struct pb_SubTimerList *subtimerlist = timers->sub_timer_list[timers->current]; - struct pb_SubTimer *curr = (subtimerlist == NULL) ? NULL : subtimerlist->current; - if (timers->current != pb_TimerID_NONE) { if (curr != NULL && topLevelToStop != NULL) { pb_StopTimerAndSubTimer(topLevelToStop, &curr->timer); @@ -484,11 +446,11 @@ pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID cat pb_StopTimer(topLevelToStop); } } - + subtimerlist = timers->sub_timer_list[category]; struct pb_SubTimer *subtimer = NULL; - - if (label != NULL) { + + if (label != NULL) { subtimer = subtimerlist->subtimer_list; while (subtimer != NULL) { if (strcmp(subtimer->label, label) == 0) { @@ -497,46 +459,45 @@ pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID cat subtimer = subtimer->next; } } - } - + } + if (category != pb_TimerID_NONE) { - + if (subtimerlist != NULL) { subtimerlist->current = subtimer; } - + if (category != timers->current && subtimer != NULL) { pb_StartTimerAndSubTimer(&timers->timers[category], &subtimer->timer); } else if (subtimer != NULL) { // Same category, different non-NULL subtimer pb_StartTimer(&subtimer->timer); - } else{ - // Different category, but no subtimer (not found or specified as NULL) -- unprefered way of setting topLevel timer + } else { + // Different category, but no subtimer (not found or specified as NULL) -- + // unprefered way of setting topLevel timer pb_StartTimer(&timers->timers[category]); } - } - + } + timers->current = category; - } -void -pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer) -{ +void pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer) { /* Stop the currently running timer */ if (timers->current != pb_TimerID_NONE) { struct pb_SubTimer *currSubTimer = NULL; - struct pb_SubTimerList *subtimerlist = timers->sub_timer_list[timers->current]; - - if ( subtimerlist != NULL) { + struct pb_SubTimerList *subtimerlist = + timers->sub_timer_list[timers->current]; + + if (subtimerlist != NULL) { currSubTimer = timers->sub_timer_list[timers->current]->current; } - if ( currSubTimer!= NULL) { - pb_StopTimerAndSubTimer(&timers->timers[timers->current], &currSubTimer->timer); + if (currSubTimer != NULL) { + pb_StopTimerAndSubTimer(&timers->timers[timers->current], + &currSubTimer->timer); } else { pb_StopTimer(&timers->timers[timers->current]); } - } timers->current = timer; @@ -546,30 +507,29 @@ pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer) } } -void -pb_PrintTimerSet(struct pb_TimerSet *timers) -{ +void pb_PrintTimerSet(struct pb_TimerSet *timers) { pb_Timestamp wall_end = get_time(); struct pb_Timer *t = timers->timers; - struct pb_SubTimer* sub = NULL; - + struct pb_SubTimer *sub = NULL; + int maxSubLength; - - const char *categories[] = { - "IO", "Kernel", "Copy", "Driver", "Copy Async", "Compute" - }; - + + const char *categories[] = {"IO", "Kernel", "Copy", + "Driver", "Copy Async", "Compute"}; + const int maxCategoryLength = 10; - + int i; - for(i = 1; i < pb_TimerID_LAST-1; ++i) { // exclude NONE and OVRELAP from this format - if(pb_GetElapsedTime(&t[i]) != 0) { - + for (i = 1; i < pb_TimerID_LAST - 1; + ++i) { // exclude NONE and OVRELAP from this format + if (pb_GetElapsedTime(&t[i]) != 0) { + // Print Category Timer - printf("%-*s: %f\n", maxCategoryLength, categories[i-1], pb_GetElapsedTime(&t[i])); - + printf("%-*s: %f\n", maxCategoryLength, categories[i - 1], + pb_GetElapsedTime(&t[i])); + if (timers->sub_timer_list[i] != NULL) { sub = timers->sub_timer_list[i]->subtimer_list; maxSubLength = 0; @@ -580,44 +540,44 @@ pb_PrintTimerSet(struct pb_TimerSet *timers) } sub = sub->next; } - + // Fit to Categories if (maxSubLength <= maxCategoryLength) { - maxSubLength = maxCategoryLength; + maxSubLength = maxCategoryLength; } - + sub = timers->sub_timer_list[i]->subtimer_list; - + // Print SubTimers while (sub != NULL) { - printf(" -%-*s: %f\n", maxSubLength, sub->label, pb_GetElapsedTime(&sub->timer)); + printf(" -%-*s: %f\n", maxSubLength, sub->label, + pb_GetElapsedTime(&sub->timer)); sub = sub->next; } } } } - - if(pb_GetElapsedTime(&t[pb_TimerID_OVERLAP]) != 0) - printf("CPU/Kernel Overlap: %f\n", pb_GetElapsedTime(&t[pb_TimerID_OVERLAP])); - - float walltime = (wall_end - timers->wall_begin)/ 1e6; - printf("Timer Wall Time: %f\n", walltime); - + + if (pb_GetElapsedTime(&t[pb_TimerID_OVERLAP]) != 0) + printf("CPU/Kernel Overlap: %f\n", + pb_GetElapsedTime(&t[pb_TimerID_OVERLAP])); + + float walltime = (wall_end - timers->wall_begin) / 1e6; + printf("Timer Wall Time: %f\n", walltime); } -void pb_DestroyTimerSet(struct pb_TimerSet * timers) -{ +void pb_DestroyTimerSet(struct pb_TimerSet *timers) { /* clean up all of the async event markers */ - struct pb_async_time_marker_list ** event = &(timers->async_markers); - while( *event != NULL) { - struct pb_async_time_marker_list ** next = &((*event)->next); + struct pb_async_time_marker_list **event = &(timers->async_markers); + while (*event != NULL) { + struct pb_async_time_marker_list **next = &((*event)->next); free(*event); (*event) = NULL; event = next; } - + int i = 0; - for(i = 0; i < pb_TimerID_LAST; ++i) { + for (i = 0; i < pb_TimerID_LAST; ++i) { if (timers->sub_timer_list[i] != NULL) { struct pb_SubTimer *subtimer = timers->sub_timer_list[i]->subtimer_list; struct pb_SubTimer *prev = NULL; @@ -631,5 +591,3 @@ void pb_DestroyTimerSet(struct pb_TimerSet * timers) } } } - - diff --git a/sample_programs/cpp_sample_programs/bfs/parboil.h b/sample_programs/cpp_sample_programs/bfs/parboil.h old mode 100755 new mode 100644 index bc4891f2..f2222c67 --- a/sample_programs/cpp_sample_programs/bfs/parboil.h +++ b/sample_programs/cpp_sample_programs/bfs/parboil.h @@ -12,13 +12,13 @@ extern "C" { /* Command line parameters for benchmarks */ struct pb_Parameters { - char *outFile; /* If not NULL, the raw output of the - * computation should be saved to this - * file. The string is owned. */ - char **inpFiles; /* A NULL-terminated array of strings - * holding the input file(s) for the - * computation. The array and strings - * are owned. */ + char *outFile; /* If not NULL, the raw output of the + * computation should be saved to this + * file. The string is owned. */ + char **inpFiles; /* A NULL-terminated array of strings + * holding the input file(s) for the + * computation. The array and strings + * are owned. */ }; /* Read command-line parameters. @@ -30,24 +30,21 @@ struct pb_Parameters { * If there is an error, then an error message is printed on stderr * and NULL is returned. */ -struct pb_Parameters * -pb_ReadParameters(int *_argc, char **argv); +struct pb_Parameters *pb_ReadParameters(int *_argc, char **argv); /* Free an instance of struct pb_Parameters. */ -void -pb_FreeParameters(struct pb_Parameters *p); +void pb_FreeParameters(struct pb_Parameters *p); /* Count the number of input files in a pb_Parameters instance. */ -int -pb_Parameters_CountInputs(struct pb_Parameters *p); +int pb_Parameters_CountInputs(struct pb_Parameters *p); /* A time or duration. */ #if _POSIX_VERSION >= 200112L typedef unsigned long long pb_Timestamp; /* time in microseconds */ #else -# error "Timestamps not implemented" +#error "Timestamps not implemented" #endif enum pb_TimerState { @@ -57,68 +54,64 @@ enum pb_TimerState { struct pb_Timer { enum pb_TimerState state; - pb_Timestamp elapsed; /* Amount of time elapsed so far */ - pb_Timestamp init; /* Beginning of the current time interval, - * if state is RUNNING. End of the last - * recorded time interfal otherwise. */ + pb_Timestamp elapsed; /* Amount of time elapsed so far */ + pb_Timestamp init; /* Beginning of the current time interval, + * if state is RUNNING. End of the last + * recorded time interfal otherwise. */ }; /* Reset a timer. * Use this to initialize a timer or to clear * its elapsed time. The reset timer is stopped. */ -void -pb_ResetTimer(struct pb_Timer *timer); +void pb_ResetTimer(struct pb_Timer *timer); /* Start a timer. The timer is set to RUNNING mode and * time elapsed while the timer is running is added to * the timer. * The timer should not already be running. */ -void -pb_StartTimer(struct pb_Timer *timer); +void pb_StartTimer(struct pb_Timer *timer); /* Stop a timer. * This stops adding elapsed time to the timer. * The timer should not already be stopped. */ -void -pb_StopTimer(struct pb_Timer *timer); +void pb_StopTimer(struct pb_Timer *timer); /* Get the elapsed time in seconds. */ -double -pb_GetElapsedTime(struct pb_Timer *timer); +double pb_GetElapsedTime(struct pb_Timer *timer); /* Execution time is assigned to one of these categories. */ enum pb_TimerID { pb_TimerID_NONE = 0, - pb_TimerID_IO, /* Time spent in input/output */ - pb_TimerID_KERNEL, /* Time spent computing on the device, - * recorded asynchronously */ - pb_TimerID_COPY, /* Time spent synchronously moving data - * to/from device and allocating/freeing - * memory on the device */ - pb_TimerID_DRIVER, /* Time spent in the host interacting with the - * driver, primarily for recording the time - * spent queueing asynchronous operations */ - pb_TimerID_COPY_ASYNC, /* Time spent in asynchronous transfers */ - pb_TimerID_COMPUTE, /* Time for all program execution other - * than parsing command line arguments, - * I/O, kernel, and copy */ - pb_TimerID_OVERLAP, /* Time double-counted in asynchronous and - * host activity: automatically filled in, - * not intended for direct usage */ - pb_TimerID_LAST /* Number of timer IDs */ + pb_TimerID_IO, /* Time spent in input/output */ + pb_TimerID_KERNEL, /* Time spent computing on the device, + * recorded asynchronously */ + pb_TimerID_COPY, /* Time spent synchronously moving data + * to/from device and allocating/freeing + * memory on the device */ + pb_TimerID_DRIVER, /* Time spent in the host interacting with the + * driver, primarily for recording the time + * spent queueing asynchronous operations */ + pb_TimerID_COPY_ASYNC, /* Time spent in asynchronous transfers */ + pb_TimerID_COMPUTE, /* Time for all program execution other + * than parsing command line arguments, + * I/O, kernel, and copy */ + pb_TimerID_OVERLAP, /* Time double-counted in asynchronous and + * host activity: automatically filled in, + * not intended for direct usage */ + pb_TimerID_LAST /* Number of timer IDs */ }; /* Dynamic list of asynchronously tracked times between events */ struct pb_async_time_marker_list { - char *label; // actually just a pointer to a string - enum pb_TimerID timerID; /* The ID to which the interval beginning - * with this marker should be attributed */ - void * marker; - //cudaEvent_t marker; /* The driver event for this marker */ - struct pb_async_time_marker_list *next; + char *label; // actually just a pointer to a string + enum pb_TimerID timerID; /* The ID to which the interval beginning + * with this marker should be attributed */ + void *marker; + // cudaEvent_t marker; /* The driver event for this marker */ + struct pb_async_time_marker_list *next; }; struct pb_SubTimer { @@ -135,7 +128,7 @@ struct pb_SubTimerList { /* A set of timers for recording execution times. */ struct pb_TimerSet { enum pb_TimerID current; - struct pb_async_time_marker_list* async_markers; + struct pb_async_time_marker_list *async_markers; pb_Timestamp async_begin; pb_Timestamp wall_begin; struct pb_Timer timers[pb_TimerID_LAST]; @@ -143,34 +136,29 @@ struct pb_TimerSet { }; /* Reset all timers in the set. */ -void -pb_InitializeTimerSet(struct pb_TimerSet *timers); +void pb_InitializeTimerSet(struct pb_TimerSet *timers); -void -pb_AddSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID pb_Category); +void pb_AddSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID pb_Category); /* Select which timer the next interval of time should be accounted * to. The selected timer is started and other timers are stopped. * Using pb_TimerID_NONE stops all timers. */ -void -pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); +void pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); -void -pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID category); +void pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID category); /* Print timer values to standard output. */ -void -pb_PrintTimerSet(struct pb_TimerSet *timers); +void pb_PrintTimerSet(struct pb_TimerSet *timers); /* Release timer resources */ -void -pb_DestroyTimerSet(struct pb_TimerSet * timers); +void pb_DestroyTimerSet(struct pb_TimerSet *timers); -void -pb_SetOpenCL(void *clContextPtr, void *clCommandQueuePtr); +void pb_SetOpenCL(void *clContextPtr, void *clCommandQueuePtr); #ifdef __cplusplus } #endif -#endif //PARBOIL_HEADER +#endif // PARBOIL_HEADER diff --git a/sample_programs/cpp_sample_programs/factorial/compileAndRun.sh b/sample_programs/cpp_sample_programs/factorial/compileAndRun.sh index 6e0c4637..63882ec0 100755 --- a/sample_programs/cpp_sample_programs/factorial/compileAndRun.sh +++ b/sample_programs/cpp_sample_programs/factorial/compileAndRun.sh @@ -4,7 +4,7 @@ rm -rf ./llfi* -fname=$1 +fname="factorial" # Generate Makefile $LLFI_BUILD_ROOT/tools/GenerateMakefile --readable --all -o "$fname.ll" @@ -14,10 +14,9 @@ make $LLFI_BUILD_ROOT/bin/instrument --readable "$fname.ll" # Call the profiling script -shift -$LLFI_BUILD_ROOT/bin/profile ./llfi/"$fname-profiling.exe" $@ +$LLFI_BUILD_ROOT/bin/profile ./llfi/"$fname-profiling.exe" $1 # Inject the faults -$LLFI_BUILD_ROOT/bin/injectfault ./llfi/"$fname-faultinjection.exe" $@ +$LLFI_BUILD_ROOT/bin/injectfault ./llfi/"$fname-faultinjection.exe" $1 echo "Done injecting faults" diff --git a/sample_programs/cpp_sample_programs/factorial/factorial.c b/sample_programs/cpp_sample_programs/factorial/factorial.c index d761f42d..40793ea6 100755 --- a/sample_programs/cpp_sample_programs/factorial/factorial.c +++ b/sample_programs/cpp_sample_programs/factorial/factorial.c @@ -1,16 +1,14 @@ #include #include -int main(argc, argv) -int argc; -char *argv[]; +int main(int argc, char *argv[]) { - int i,fact, n; - n = atoi(argv[1]); - fact = 1; - for(i=1;i<=n;i++) - { - fact = fact * i; - } - printf("%d\n",fact); + int i, fact, n; + n = atoi(argv[1]); + fact = 1; + for (i = 1; i <= n; i++) { + fact = fact * i; + } + printf("%d\n", fact); + return 0; } diff --git a/sample_programs/cpp_sample_programs/sad/image.h b/sample_programs/cpp_sample_programs/sad/image.h index 27fc3e0b..2547dcdb 100644 --- a/sample_programs/cpp_sample_programs/sad/image.h +++ b/sample_programs/cpp_sample_programs/sad/image.h @@ -6,8 +6,7 @@ *cr ***************************************************************************/ -struct image_i16 -{ +struct image_i16 { int width; int height; short *data; @@ -17,7 +16,7 @@ struct image_i16 extern "C" { #endif -struct image_i16 * load_image(char *filename); +struct image_i16 *load_image(char *filename); void free_image(struct image_i16 *); #ifdef __cplusplus diff --git a/sample_programs/cpp_sample_programs/sad/parboil.h b/sample_programs/cpp_sample_programs/sad/parboil.h index 9885c9dc..f1f16283 100644 --- a/sample_programs/cpp_sample_programs/sad/parboil.h +++ b/sample_programs/cpp_sample_programs/sad/parboil.h @@ -10,20 +10,20 @@ extern "C" { /* Command line parameters for benchmarks */ struct pb_Parameters { - char *outFile; /* If not NULL, the raw output of the - * computation should be saved to this - * file. The string is owned. */ - char **inpFiles; /* A NULL-terminated array of strings - * holding the input file(s) for the - * computation. The array and strings - * are owned. */ - int synchronizeGpu; /* Controls behavior of CUDA benchmarks. - * If nonzero, a CUDA runtime - * synchronization call should happen - * after each data transfer to the GPU - * and after each kernel call. This - * is necessary for accurate timing - * measurement. */ + char *outFile; /* If not NULL, the raw output of the + * computation should be saved to this + * file. The string is owned. */ + char **inpFiles; /* A NULL-terminated array of strings + * holding the input file(s) for the + * computation. The array and strings + * are owned. */ + int synchronizeGpu; /* Controls behavior of CUDA benchmarks. + * If nonzero, a CUDA runtime + * synchronization call should happen + * after each data transfer to the GPU + * and after each kernel call. This + * is necessary for accurate timing + * measurement. */ }; /* Read command-line parameters. @@ -35,24 +35,21 @@ struct pb_Parameters { * If there is an error, then an error message is printed on stderr * and NULL is returned. */ -struct pb_Parameters * -pb_ReadParameters(int *_argc, char **argv); +struct pb_Parameters *pb_ReadParameters(int *_argc, char **argv); /* Free an instance of struct pb_Parameters. */ -void -pb_FreeParameters(struct pb_Parameters *p); +void pb_FreeParameters(struct pb_Parameters *p); /* Count the number of input files in a pb_Parameters instance. */ -int -pb_Parameters_CountInputs(struct pb_Parameters *p); +int pb_Parameters_CountInputs(struct pb_Parameters *p); /* A time or duration. */ #if _POSIX_VERSION >= 200112L typedef unsigned long long pb_Timestamp; /* time in microseconds */ #else -# error "Timestamps not implemented" +#error "Timestamps not implemented" #endif enum pb_TimerState { @@ -62,49 +59,45 @@ enum pb_TimerState { struct pb_Timer { enum pb_TimerState state; - pb_Timestamp elapsed; /* Amount of time elapsed so far */ - pb_Timestamp init; /* Beginning of the current time interval, - * if state is RUNNING. Undefined - * otherwise. */ + pb_Timestamp elapsed; /* Amount of time elapsed so far */ + pb_Timestamp init; /* Beginning of the current time interval, + * if state is RUNNING. Undefined + * otherwise. */ }; /* Reset a timer. * Use this to initialize a timer or to clear * its elapsed time. The reset timer is stopped. */ -void -pb_ResetTimer(struct pb_Timer *timer); +void pb_ResetTimer(struct pb_Timer *timer); /* Start a timer. The timer is set to RUNNING mode and * time elapsed while the timer is running is added to * the timer. * The timer should not already be running. */ -void -pb_StartTimer(struct pb_Timer *timer); +void pb_StartTimer(struct pb_Timer *timer); /* Stop a timer. * This stops adding elapsed time to the timer. * The timer should not already be stopped. */ -void -pb_StopTimer(struct pb_Timer *timer); +void pb_StopTimer(struct pb_Timer *timer); /* Get the elapsed time in seconds. */ -double -pb_GetElapsedTime(struct pb_Timer *timer); +double pb_GetElapsedTime(struct pb_Timer *timer); /* Execution time is assigned to one of these categories. */ enum pb_TimerID { pb_TimerID_NONE = 0, - pb_TimerID_IO, /* Time spent in input/output */ - pb_TimerID_GPU, /* Time spent computing on the GPU */ - pb_TimerID_COPY, /* Time spent moving data to/from GPU and - * allocating/freeing memory on the GPU */ - pb_TimerID_COMPUTE, /* Time for all program execution other - * than parsing command line arguments, - * I/O, GPU, and copy */ - pb_TimerID_LAST /* Number of timer IDs */ + pb_TimerID_IO, /* Time spent in input/output */ + pb_TimerID_GPU, /* Time spent computing on the GPU */ + pb_TimerID_COPY, /* Time spent moving data to/from GPU and + * allocating/freeing memory on the GPU */ + pb_TimerID_COMPUTE, /* Time for all program execution other + * than parsing command line arguments, + * I/O, GPU, and copy */ + pb_TimerID_LAST /* Number of timer IDs */ }; /* A set of timers for recording execution times. */ @@ -114,18 +107,15 @@ struct pb_TimerSet { }; /* Reset all timers in the set. */ -void -pb_InitializeTimerSet(struct pb_TimerSet *timers); +void pb_InitializeTimerSet(struct pb_TimerSet *timers); /* Select which timer the next interval of time should be accounted * to. The selected timer is started and other timers are stopped. * Using pb_TimerID_NONE stops all timers. */ -void -pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); +void pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); /* Print timer values to standard output. */ -void -pb_PrintTimerSet(struct pb_TimerSet *timers); +void pb_PrintTimerSet(struct pb_TimerSet *timers); #ifdef __cplusplus } diff --git a/sample_programs/cpp_sample_programs/sad/sad.h b/sample_programs/cpp_sample_programs/sad/sad.h index bfd8017f..c8a247ee 100644 --- a/sample_programs/cpp_sample_programs/sad/sad.h +++ b/sample_programs/cpp_sample_programs/sad/sad.h @@ -10,7 +10,7 @@ #define SEARCH_RANGE 16 /* The total search area is 33 pixels square */ -#define SEARCH_DIMENSION (2*SEARCH_RANGE+1) +#define SEARCH_DIMENSION (2 * SEARCH_RANGE + 1) /* The total number of search positions is 33^2 */ #define MAX_POS 1089 @@ -18,7 +18,7 @@ /* This is padded to a multiple of 4 when allocating memory */ #define MAX_POS_PADDED 1092 -/* VBSME block indices in the SAD array for different +/* VBSME block indices in the SAD array for different * block sizes. The index is computed from the * image size in macroblocks. Block sizes are (height, width): * 1: 16 by 16 pixels, one block per macroblock @@ -30,22 +30,27 @@ * 7: 4 by 4 pixels, 16 blocks per macroblock */ #define SAD_TYPE_1_IX(image_size) 0 -#define SAD_TYPE_2_IX(image_size) ((image_size)*MAX_POS_PADDED) -#define SAD_TYPE_3_IX(image_size) ((image_size)*(3*MAX_POS_PADDED)) -#define SAD_TYPE_4_IX(image_size) ((image_size)*(5*MAX_POS_PADDED)) -#define SAD_TYPE_5_IX(image_size) ((image_size)*(9*MAX_POS_PADDED)) -#define SAD_TYPE_6_IX(image_size) ((image_size)*(17*MAX_POS_PADDED)) -#define SAD_TYPE_7_IX(image_size) ((image_size)*(25*MAX_POS_PADDED)) +#define SAD_TYPE_2_IX(image_size) ((image_size) * MAX_POS_PADDED) +#define SAD_TYPE_3_IX(image_size) ((image_size) * (3 * MAX_POS_PADDED)) +#define SAD_TYPE_4_IX(image_size) ((image_size) * (5 * MAX_POS_PADDED)) +#define SAD_TYPE_5_IX(image_size) ((image_size) * (9 * MAX_POS_PADDED)) +#define SAD_TYPE_6_IX(image_size) ((image_size) * (17 * MAX_POS_PADDED)) +#define SAD_TYPE_7_IX(image_size) ((image_size) * (25 * MAX_POS_PADDED)) -#define SAD_TYPE_IX(n, image_size) \ - ((n == 1) ? SAD_TYPE_1_IX(image_size) : \ - ((n == 2) ? SAD_TYPE_2_IX(image_size) : \ - ((n == 3) ? SAD_TYPE_3_IX(image_size) : \ - ((n == 4) ? SAD_TYPE_4_IX(image_size) : \ - ((n == 5) ? SAD_TYPE_5_IX(image_size) : \ - ((n == 6) ? SAD_TYPE_6_IX(image_size) : \ - (SAD_TYPE_7_IX(image_size) \ - ))))))) +#define SAD_TYPE_IX(n, image_size) \ + ((n == 1) \ + ? SAD_TYPE_1_IX(image_size) \ + : ((n == 2) \ + ? SAD_TYPE_2_IX(image_size) \ + : ((n == 3) \ + ? SAD_TYPE_3_IX(image_size) \ + : ((n == 4) \ + ? SAD_TYPE_4_IX(image_size) \ + : ((n == 5) \ + ? SAD_TYPE_5_IX(image_size) \ + : ((n == 6) \ + ? SAD_TYPE_6_IX(image_size) \ + : (SAD_TYPE_7_IX(image_size)))))))) #define SAD_TYPE_1_CT 1 #define SAD_TYPE_2_CT 2 @@ -55,28 +60,27 @@ #define SAD_TYPE_6_CT 8 #define SAD_TYPE_7_CT 16 -#define SAD_TYPE_CT(n) \ - ((n == 1) ? SAD_TYPE_1_CT : \ - ((n == 2) ? SAD_TYPE_2_CT : \ - ((n == 3) ? SAD_TYPE_3_CT : \ - ((n == 4) ? SAD_TYPE_4_CT : \ - ((n == 5) ? SAD_TYPE_5_CT : \ - ((n == 6) ? SAD_TYPE_6_CT : \ - (SAD_TYPE_7_CT \ - ))))))) +#define SAD_TYPE_CT(n) \ + ((n == 1) \ + ? SAD_TYPE_1_CT \ + : ((n == 2) \ + ? SAD_TYPE_2_CT \ + : ((n == 3) \ + ? SAD_TYPE_3_CT \ + : ((n == 4) \ + ? SAD_TYPE_4_CT \ + : ((n == 5) ? SAD_TYPE_5_CT \ + : ((n == 6) ? SAD_TYPE_6_CT \ + : (SAD_TYPE_7_CT))))))) #ifdef __cplusplus extern "C" { #endif -void sad4_cpu(unsigned short *blk_sad, - unsigned short *frame, - unsigned short *ref, - int mb_width, - int mb_height); +void sad4_cpu(unsigned short *blk_sad, unsigned short *frame, + unsigned short *ref, int mb_width, int mb_height); -void larger_sads(unsigned short *sads, - int mbs); +void larger_sads(unsigned short *sads, int mbs); #ifdef __cplusplus } diff --git a/sample_programs/cpp_sample_programs/sha-MiBench/sha.h b/sample_programs/cpp_sample_programs/sha-MiBench/sha.h old mode 100755 new mode 100644 index 2aa4a5a9..4220b530 --- a/sample_programs/cpp_sample_programs/sha-MiBench/sha.h +++ b/sample_programs/cpp_sample_programs/sha-MiBench/sha.h @@ -9,13 +9,13 @@ typedef unsigned char BYTE; typedef unsigned long LONG; -#define SHA_BLOCKSIZE 64 -#define SHA_DIGESTSIZE 20 +#define SHA_BLOCKSIZE 64 +#define SHA_DIGESTSIZE 20 typedef struct { - LONG digest[5]; /* message digest */ - LONG count_lo, count_hi; /* 64-bit bit count */ - LONG data[16]; /* SHA data buffer */ + LONG digest[5]; /* message digest */ + LONG count_lo, count_hi; /* 64-bit bit count */ + LONG data[16]; /* SHA data buffer */ } SHA_INFO; void sha_init(SHA_INFO *); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/exportInputsAsTensors.py index 1d138feb..ef8a3023 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "bert-base-cased" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/input.c b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/input.c index 9fea6b13..4fd8c0ac 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/inputs/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/inputs/exportInputsAsTensors.py index 27693017..dfa270a5 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/inputs/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/inputs/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "prajjwal1/bert-medium" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/parseLLTFIJsonOp.py index 9449ac49..0187dd30 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-medium/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +39,19 @@ def main(inpSample): model = BertForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -58,8 +68,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -69,7 +79,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -87,7 +97,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/exportInputsAsTensors.py index 1d138feb..ef8a3023 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "bert-base-cased" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/input.c b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/input.c index 9fea6b13..4fd8c0ac 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/inputs/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/inputs/exportInputsAsTensors.py index 270995ab..21a9ae21 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/inputs/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/inputs/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "prajjwal1/bert-mini" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/parseLLTFIJsonOp.py index eccbb81f..89b21c83 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-mini/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +39,19 @@ def main(inpSample): model = BertForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -58,8 +68,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -69,7 +79,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -87,7 +97,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/exportInputsAsTensors.py index 1d138feb..ef8a3023 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "bert-base-cased" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/input.c b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/input.c index 9fea6b13..4fd8c0ac 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/inputs/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/inputs/exportInputsAsTensors.py index dbaede71..a76a8556 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/inputs/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/inputs/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "prajjwal1/bert-small" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/parseLLTFIJsonOp.py index 78e5bbca..c615a7c9 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-small/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +39,19 @@ def main(inpSample): model = BertForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -58,8 +68,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -69,7 +79,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -87,7 +97,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/exportInputsAsTensors.py index 1d138feb..ef8a3023 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "bert-base-cased" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/input.c b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/input.c index 9fea6b13..4fd8c0ac 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/inputs/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/inputs/exportInputsAsTensors.py index ea119053..6d38f703 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/inputs/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/inputs/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "prajjwal1/bert-tiny" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/parseLLTFIJsonOp.py index 20975806..5022fa48 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVariants/bert-tiny/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +39,19 @@ def main(inpSample): model = BertForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -58,8 +68,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -69,7 +79,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -87,7 +97,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/input.c b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/input.c index 9fea6b13..4fd8c0ac 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/inputs/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/inputs/exportInputsAsTensors.py index 1d138feb..ef8a3023 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/inputs/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/inputs/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "bert-base-cased" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/parseLLTFIJsonOp.py index 6efe9f22..d86c607e 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/bert-base-cased/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -31,19 +41,19 @@ def main(inpSample): pdb.set_trace() - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -60,8 +70,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -71,7 +81,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -89,7 +99,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/helper_scripts/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/helper_scripts/exportInputsAsTensors.py index fe6e93f2..abcf832c 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/helper_scripts/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/helper_scripts/exportInputsAsTensors.py @@ -2,26 +2,28 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is .", -"Main symptom of cancer is .", -"He is having a cough and fever, thus, he might be suffering from .", -"He is having sickness and vomiting, thus, he might be having .", -"Since you are having post-traumatic disorder, you should consult a .", -"Paracetamol is used to treat a .", -"The hereditary protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of gene transcripts.", -" is a tumor suppressor gene.", -" is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is .", + "Main symptom of cancer is .", + "He is having a cough and fever, thus, he might be suffering from .", + "He is having sickness and vomiting, thus, he might be having .", + "Since you are having post-traumatic disorder, you should consult a .", + "Paracetamol is used to treat a .", + "The hereditary protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of gene transcripts.", + " is a tumor suppressor gene.", + " is a symptom of diabetes.", +] model_name = "roberta-base" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/input.c b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/input.c index 85d66247..591bdfc8 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/parseLLTFIJsonOp.py index 064e2711..25044003 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BertVsRoberta/roberta-based/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, RobertaForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + RobertaForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is .", -"Main symptom of cancer is .", -"He is having a cough and fever, thus, he might be suffering from .", -"He is having sickness and vomiting, thus, he might be having .", -"Since you are having post-traumatic disorder, you should consult a .", -"Paracetamol is used to treat a .", -"The hereditary protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of gene transcripts.", -" is a tumor suppressor gene.", -" is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is .", + "Main symptom of cancer is .", + "He is having a cough and fever, thus, he might be suffering from .", + "He is having sickness and vomiting, thus, he might be having .", + "Since you are having post-traumatic disorder, you should consult a .", + "Paracetamol is used to treat a .", + "The hereditary protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of gene transcripts.", + " is a tumor suppressor gene.", + " is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +39,19 @@ def main(inpSample): model = RobertaForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -58,8 +68,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -69,7 +79,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -87,7 +97,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/compile.sh b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/helper_scripts/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/helper_scripts/exportInputsAsTensors.py index 8da44c21..8104ee01 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/helper_scripts/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/helper_scripts/exportInputsAsTensors.py @@ -2,27 +2,29 @@ import os from onnx import numpy_helper -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_token_type_ids = numpy_helper.from_array(inputs_np['token_type_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_token_type_ids = numpy_helper.from_array(inputs_np["token_type_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/input.c b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/input.c index 9fea6b13..4fd8c0ac 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/parseLLTFIJsonOp.py index 1804b756..91667a4d 100644 --- a/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/BioMedNLP/parseLLTFIJsonOp.py @@ -1,24 +1,34 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["Main symptom of common flu is [MASK].", -"Main symptom of cancer is [MASK].", -"He is having a cough and fever, thus, he might be suffering from [MASK].", -"He is having sickness and vomiting, thus, he might be having [MASK].", -"Since you are having post-traumatic disorder, you should consult a [MASK].", -"Paracetamol is used to treat a [MASK].", -"The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", -"Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", -"[MASK] is a tumor suppressor gene.", -"[MASK] is a symptom of diabetes."] +inputs = [ + "Main symptom of common flu is [MASK].", + "Main symptom of cancer is [MASK].", + "He is having a cough and fever, thus, he might be suffering from [MASK].", + "He is having sickness and vomiting, thus, he might be having [MASK].", + "Since you are having post-traumatic disorder, you should consult a [MASK].", + "Paracetamol is used to treat a [MASK].", + "The hereditary [MASK] protein, HFE, specifically regulates transferrin-mediated iron uptake in HeLa cells.", + "Pelizaeus-Merzbacher disease is caused by overexpression of [MASK] gene transcripts.", + "[MASK] is a tumor suppressor gene.", + "[MASK] is a symptom of diabetes.", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +39,19 @@ def main(inpSample): model = BertForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -58,8 +68,8 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) listResArr.append(resforSingleInput) listShareApp.append(shapeForSingleInput) @@ -69,7 +79,7 @@ def main(inpSample): modelOp = listResArr[i][0] modelOpShape = listShareApp[i][0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -87,7 +97,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/CodeBert/compile.sh b/sample_programs/ml_sample_programs/nlp_models/CodeBert/compile.sh index 765f2b4b..99655e86 100755 --- a/sample_programs/ml_sample_programs/nlp_models/CodeBert/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/CodeBert/compile.sh @@ -2,7 +2,7 @@ printf "\n[Compile Script]: Getting the ONNX model\n" FILE=model.onnx printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp model.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp model.onnx mlir-translate -mlir-to-llvmir model.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/CodeBert/helper_scripts/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/CodeBert/helper_scripts/exportInputsAsTensors.py index db8e1f12..add990cc 100644 --- a/sample_programs/ml_sample_programs/nlp_models/CodeBert/helper_scripts/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/CodeBert/helper_scripts/exportInputsAsTensors.py @@ -2,26 +2,28 @@ import os from onnx import numpy_helper -inputs = ["if (x is not None) (x>1)", -"public String getSecretKey ( String subdomain ) { String secretKey = secretKeys . get ( subdomain ) ; if (secretKey == ) { secretKey = defaultSecretKey ; } return secretKey ; }", -"protected Iterator < Map . Entry < K , V > > createEntrySetIterator ( ) { if ( size ( ) 0 ) { return EmptyIterator . INSTANCE ; } return new EntrySetIterator < K , V > ( this ) ; }", -"public static boolean isPermanent ( ResourceModel resourceModel ) { Object resource = resourceModel . getResource ( ) ; try { return ( Boolean ) ; } catch ( ClassCastException e ) { return false ; } }", -"public void clear ( ) { modCount ++ ; HashEntry [ ] data = this . data ; for ( int i = data . length - 1 ; i 0 ; i -- ) { data [ i ] = null ; } size = 0 ; }", -"public void addPrincipal ( String principal ) { if ( ! readOnly ! principals . contains ( principal ) ) { principals . add ( principal ) ; principalsModified = true ; } }", -"protected final void addAllApplications ( Set < Class < ? > > set ) { for ( Class < ? > cls : set ) { if ( ! cls . isInterface ( ) && ! Modifier . isAbstract ( cls . getModifiers ( ) ) ) { if ( ! this . classmap . ( cls ) ) { this . classNames . add ( cls . getName ( ) ) ; } } } }", -"public void setName ( String name ) { if ( name != null && name . ( this . name ) ) { return ; } this . name = name ; Roster packet = new Roster ( ) ; packet . setType ( IQ . Type . set ) ; packet . addItem ( new JID ( user ) , name , ask , subscription , getGroupNames ( ) ) ; connection . sendPacket ( packet ) ; }", -"""public String getString ( String defaultValue ) { if ( value instanceof String || value instanceof Number ) { return value . toString ( ) ; } if ( value == null ) { return null ; } if ( value instanceof JSONArray ) { return ( ( JSONArray ) value ) . toJSONString ( ) ; } if ( value instanceof JSONObject ) { return ( ( JSONObject ) value ) . ( ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected string:" ) ; }""", -""" public Double getDouble ( Double defaultValue ) { if ( value instanceof Number ) { return ( ( Number ) value ) . doubleValue ( ) ; } if ( value instanceof String ) { String s = ( String ) value ; return Double . ( s ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected number:" ) ; }"""] +inputs = [ + "if (x is not None) (x>1)", + "public String getSecretKey ( String subdomain ) { String secretKey = secretKeys . get ( subdomain ) ; if (secretKey == ) { secretKey = defaultSecretKey ; } return secretKey ; }", + "protected Iterator < Map . Entry < K , V > > createEntrySetIterator ( ) { if ( size ( ) 0 ) { return EmptyIterator . INSTANCE ; } return new EntrySetIterator < K , V > ( this ) ; }", + "public static boolean isPermanent ( ResourceModel resourceModel ) { Object resource = resourceModel . getResource ( ) ; try { return ( Boolean ) ; } catch ( ClassCastException e ) { return false ; } }", + "public void clear ( ) { modCount ++ ; HashEntry [ ] data = this . data ; for ( int i = data . length - 1 ; i 0 ; i -- ) { data [ i ] = null ; } size = 0 ; }", + "public void addPrincipal ( String principal ) { if ( ! readOnly ! principals . contains ( principal ) ) { principals . add ( principal ) ; principalsModified = true ; } }", + "protected final void addAllApplications ( Set < Class < ? > > set ) { for ( Class < ? > cls : set ) { if ( ! cls . isInterface ( ) && ! Modifier . isAbstract ( cls . getModifiers ( ) ) ) { if ( ! this . classmap . ( cls ) ) { this . classNames . add ( cls . getName ( ) ) ; } } } }", + "public void setName ( String name ) { if ( name != null && name . ( this . name ) ) { return ; } this . name = name ; Roster packet = new Roster ( ) ; packet . setType ( IQ . Type . set ) ; packet . addItem ( new JID ( user ) , name , ask , subscription , getGroupNames ( ) ) ; connection . sendPacket ( packet ) ; }", + """public String getString ( String defaultValue ) { if ( value instanceof String || value instanceof Number ) { return value . toString ( ) ; } if ( value == null ) { return null ; } if ( value instanceof JSONArray ) { return ( ( JSONArray ) value ) . toJSONString ( ) ; } if ( value instanceof JSONObject ) { return ( ( JSONObject ) value ) . ( ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected string:" ) ; }""", + """ public Double getDouble ( Double defaultValue ) { if ( value instanceof Number ) { return ( ( Number ) value ) . doubleValue ( ) ; } if ( value instanceof String ) { String s = ( String ) value ; return Double . ( s ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected number:" ) ; }""", +] model_name = "microsoft/codebert-base-mlm" tokenizer = AutoTokenizer.from_pretrained(model_name) for i in range(0, len(inputs)): input_txt = inputs[i] - sequence = (input_txt) + sequence = input_txt inputs_np = tokenizer(sequence, return_tensors="np") - tensor_input_ids = numpy_helper.from_array(inputs_np['input_ids']) - tensor_attention_mask = numpy_helper.from_array(inputs_np['attention_mask']) + tensor_input_ids = numpy_helper.from_array(inputs_np["input_ids"]) + tensor_attention_mask = numpy_helper.from_array(inputs_np["attention_mask"]) with open(os.path.join("", f"input{i}_0.pb"), "wb") as f: f.write(tensor_input_ids.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/CodeBert/input.c b/sample_programs/ml_sample_programs/nlp_models/CodeBert/input.c index 85d66247..591bdfc8 100644 --- a/sample_programs/ml_sample_programs/nlp_models/CodeBert/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/CodeBert/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/CodeBert/parseLLTFIJsonOp.py b/sample_programs/ml_sample_programs/nlp_models/CodeBert/parseLLTFIJsonOp.py index a4efc881..23944e72 100644 --- a/sample_programs/ml_sample_programs/nlp_models/CodeBert/parseLLTFIJsonOp.py +++ b/sample_programs/ml_sample_programs/nlp_models/CodeBert/parseLLTFIJsonOp.py @@ -1,24 +1,36 @@ import tensorflow as tf from transformers import TFAutoModelForMaskedLM, AutoTokenizer -from transformers import AutoTokenizer, AutoModel, AutoModelWithLMHead, BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM +from transformers import ( + AutoTokenizer, + AutoModel, + AutoModelWithLMHead, + BertTokenizer, + BertForMaskedLM, + RobertaTokenizer, + RobertaForMaskedLM, +) import os, glob, json, pdb, sys from onnx import numpy_helper from onnxruntime import InferenceSession import numpy as np -inputs = ["if (x is not None) (x>1)", -"public String getSecretKey ( String subdomain ) { String secretKey = secretKeys . get ( subdomain ) ; if (secretKey == ) { secretKey = defaultSecretKey ; } return secretKey ; }", -"protected Iterator < Map . Entry < K , V > > createEntrySetIterator ( ) { if ( size ( ) 0 ) { return EmptyIterator . INSTANCE ; } return new EntrySetIterator < K , V > ( this ) ; }", -"public static boolean isPermanent ( ResourceModel resourceModel ) { Object resource = resourceModel . getResource ( ) ; try { return ( Boolean ) ; } catch ( ClassCastException e ) { return false ; } }", -"public void clear ( ) { modCount ++ ; HashEntry [ ] data = this . data ; for ( int i = data . length - 1 ; i 0 ; i -- ) { data [ i ] = null ; } size = 0 ; }", -"public void addPrincipal ( String principal ) { if ( ! readOnly ! principals . contains ( principal ) ) { principals . add ( principal ) ; principalsModified = true ; } }", -"protected final void addAllApplications ( Set < Class < ? > > set ) { for ( Class < ? > cls : set ) { if ( ! cls . isInterface ( ) && ! Modifier . isAbstract ( cls . getModifiers ( ) ) ) { if ( ! this . classmap . ( cls ) ) { this . classNames . add ( cls . getName ( ) ) ; } } } }", -"public void setName ( String name ) { if ( name != null && name . ( this . name ) ) { return ; } this . name = name ; Roster packet = new Roster ( ) ; packet . setType ( IQ . Type . set ) ; packet . addItem ( new JID ( user ) , name , ask , subscription , getGroupNames ( ) ) ; connection . sendPacket ( packet ) ; }", -"""public String getString ( String defaultValue ) { if ( value instanceof String || value instanceof Number ) { return value . toString ( ) ; } if ( value == null ) { return null ; } if ( value instanceof JSONArray ) { return ( ( JSONArray ) value ) . toJSONString ( ) ; } if ( value instanceof JSONObject ) { return ( ( JSONObject ) value ) . ( ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected string:" ) ; }""", -""" public Double getDouble ( Double defaultValue ) { if ( value instanceof Number ) { return ( ( Number ) value ) . doubleValue ( ) ; } if ( value instanceof String ) { String s = ( String ) value ; return Double . ( s ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected number:" ) ; }"""] +inputs = [ + "if (x is not None) (x>1)", + "public String getSecretKey ( String subdomain ) { String secretKey = secretKeys . get ( subdomain ) ; if (secretKey == ) { secretKey = defaultSecretKey ; } return secretKey ; }", + "protected Iterator < Map . Entry < K , V > > createEntrySetIterator ( ) { if ( size ( ) 0 ) { return EmptyIterator . INSTANCE ; } return new EntrySetIterator < K , V > ( this ) ; }", + "public static boolean isPermanent ( ResourceModel resourceModel ) { Object resource = resourceModel . getResource ( ) ; try { return ( Boolean ) ; } catch ( ClassCastException e ) { return false ; } }", + "public void clear ( ) { modCount ++ ; HashEntry [ ] data = this . data ; for ( int i = data . length - 1 ; i 0 ; i -- ) { data [ i ] = null ; } size = 0 ; }", + "public void addPrincipal ( String principal ) { if ( ! readOnly ! principals . contains ( principal ) ) { principals . add ( principal ) ; principalsModified = true ; } }", + "protected final void addAllApplications ( Set < Class < ? > > set ) { for ( Class < ? > cls : set ) { if ( ! cls . isInterface ( ) && ! Modifier . isAbstract ( cls . getModifiers ( ) ) ) { if ( ! this . classmap . ( cls ) ) { this . classNames . add ( cls . getName ( ) ) ; } } } }", + "public void setName ( String name ) { if ( name != null && name . ( this . name ) ) { return ; } this . name = name ; Roster packet = new Roster ( ) ; packet . setType ( IQ . Type . set ) ; packet . addItem ( new JID ( user ) , name , ask , subscription , getGroupNames ( ) ) ; connection . sendPacket ( packet ) ; }", + """public String getString ( String defaultValue ) { if ( value instanceof String || value instanceof Number ) { return value . toString ( ) ; } if ( value == null ) { return null ; } if ( value instanceof JSONArray ) { return ( ( JSONArray ) value ) . toJSONString ( ) ; } if ( value instanceof JSONObject ) { return ( ( JSONObject ) value ) . ( ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected string:" ) ; }""", + """ public Double getDouble ( Double defaultValue ) { if ( value instanceof Number ) { return ( ( Number ) value ) . doubleValue ( ) ; } if ( value instanceof String ) { String s = ( String ) value ; return Double . ( s ) ; } if ( value == null ) { return defaultValue ; } throw createException ( "Expected number:" ) ; }""", +] + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): @@ -29,19 +41,19 @@ def main(inpSample): model = RobertaForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - sequence = (inputs[inpSample]) + sequence = inputs[inpSample] inputs = tokenizer(sequence, return_tensors="pt") inputs_np = tokenizer(sequence, return_tensors="np") inputs_tf = tokenizer(sequence, return_tensors="tf") mask_token_index = tf.where(inputs_tf["input_ids"] == tokenizer.mask_token_id)[0, 1] - #Path to LLTFI layer output + # Path to LLTFI layer output ROOT = os.getcwd() - LLFI_OUT = os.path.join(ROOT, 'llfi') - PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') - OUT = os.path.join(ROOT, 'out') - pathOutput = os.path.join(OUT, 'onnx-pred') - filePred = os.path.join(pathOutput, 'onnx-pred.txt') + LLFI_OUT = os.path.join(ROOT, "llfi") + PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + OUT = os.path.join(ROOT, "out") + pathOutput = os.path.join(OUT, "onnx-pred") + filePred = os.path.join(pathOutput, "onnx-pred.txt") # Read LLTFI output from llfi/prog_output and add it to 'listResArr' txtfiles = [] @@ -60,14 +72,14 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - shapeForSingleInput.append(value['Shape']) + resforSingleInput.append(value["Data"]) + shapeForSingleInput.append(value["Shape"]) outputs += f"\n-------- Run: {i} -----------\n" modelOp = resforSingleInput[0] modelOpShape = shapeForSingleInput[0] - npArr = np.array(modelOp) + npArr = np.array(modelOp) npArr = np.reshape(npArr, modelOpShape[1:]) maskedTokenLogits = npArr[mask_token_index.numpy()] top_5_tokens = tf.math.top_k(maskedTokenLogits, 5).indices.numpy() @@ -86,7 +98,6 @@ def main(inpSample): write_file.write(outputs) - if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/__init__.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/__init__.py index effb57b1..deb59588 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/__init__.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/__init__.py @@ -12,4 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/create_pretraining_data.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/create_pretraining_data.py index aca98ff7..5c66d684 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/create_pretraining_data.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/create_pretraining_data.py @@ -27,443 +27,508 @@ FLAGS = flags.FLAGS -flags.DEFINE_string("input_file", None, - "Input raw text file (or comma-separated list of files).") +flags.DEFINE_string( + "input_file", None, "Input raw text file (or comma-separated list of files)." +) flags.DEFINE_string( - "output_file", None, - "Output TF example file (or comma-separated list of files).") + "output_file", None, "Output TF example file (or comma-separated list of files)." +) -flags.DEFINE_string("vocab_file", None, - "The vocabulary file that the BERT model was trained on.") +flags.DEFINE_string( + "vocab_file", None, "The vocabulary file that the BERT model was trained on." +) flags.DEFINE_bool( - "do_lower_case", True, + "do_lower_case", + True, "Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.") + "models and False for cased models.", +) flags.DEFINE_bool( - "do_whole_word_mask", False, - "Whether to use whole word masking rather than per-WordPiece masking.") + "do_whole_word_mask", + False, + "Whether to use whole word masking rather than per-WordPiece masking.", +) flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") -flags.DEFINE_integer("max_predictions_per_seq", 20, - "Maximum number of masked LM predictions per sequence.") +flags.DEFINE_integer( + "max_predictions_per_seq", + 20, + "Maximum number of masked LM predictions per sequence.", +) flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") flags.DEFINE_integer( - "dupe_factor", 10, - "Number of times to duplicate the input data (with different masks).") + "dupe_factor", + 10, + "Number of times to duplicate the input data (with different masks).", +) flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") flags.DEFINE_float( - "short_seq_prob", 0.1, - "Probability of creating sequences which are shorter than the " - "maximum length.") + "short_seq_prob", + 0.1, + "Probability of creating sequences which are shorter than the " "maximum length.", +) class TrainingInstance(object): - """A single training instance (sentence pair).""" - - def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, - is_random_next): - self.tokens = tokens - self.segment_ids = segment_ids - self.is_random_next = is_random_next - self.masked_lm_positions = masked_lm_positions - self.masked_lm_labels = masked_lm_labels - - def __str__(self): - s = "" - s += "tokens: %s\n" % (" ".join( - [tokenization.printable_text(x) for x in self.tokens])) - s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) - s += "is_random_next: %s\n" % self.is_random_next - s += "masked_lm_positions: %s\n" % (" ".join( - [str(x) for x in self.masked_lm_positions])) - s += "masked_lm_labels: %s\n" % (" ".join( - [tokenization.printable_text(x) for x in self.masked_lm_labels])) - s += "\n" - return s - - def __repr__(self): - return self.__str__() - - -def write_instance_to_example_files(instances, tokenizer, max_seq_length, - max_predictions_per_seq, output_files): - """Create TF example files from `TrainingInstance`s.""" - writers = [] - for output_file in output_files: - writers.append(tf.io.TFRecordWriter(output_file)) - - writer_index = 0 - - total_written = 0 - for (inst_index, instance) in enumerate(instances): - input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) - input_mask = [1] * len(input_ids) - segment_ids = list(instance.segment_ids) - assert len(input_ids) <= max_seq_length - - while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) - - assert len(input_ids) == max_seq_length - assert len(input_mask) == max_seq_length - assert len(segment_ids) == max_seq_length - - masked_lm_positions = list(instance.masked_lm_positions) - masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) - masked_lm_weights = [1.0] * len(masked_lm_ids) - - while len(masked_lm_positions) < max_predictions_per_seq: - masked_lm_positions.append(0) - masked_lm_ids.append(0) - masked_lm_weights.append(0.0) - - next_sentence_label = 1 if instance.is_random_next else 0 - - features = collections.OrderedDict() - features["input_ids"] = create_int_feature(input_ids) - features["input_mask"] = create_int_feature(input_mask) - features["segment_ids"] = create_int_feature(segment_ids) - features["masked_lm_positions"] = create_int_feature(masked_lm_positions) - features["masked_lm_ids"] = create_int_feature(masked_lm_ids) - features["masked_lm_weights"] = create_float_feature(masked_lm_weights) - features["next_sentence_labels"] = create_int_feature([next_sentence_label]) - - tf_example = tf.train.Example(features=tf.train.Features(feature=features)) - - writers[writer_index].write(tf_example.SerializeToString()) - writer_index = (writer_index + 1) % len(writers) - - total_written += 1 - - if inst_index < 20: - tf.compat.v1.logging.info("*** Example ***") - tf.compat.v1.logging.info("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in instance.tokens])) - - for feature_name in features.keys(): - feature = features[feature_name] - values = [] - if feature.int64_list.value: - values = feature.int64_list.value - elif feature.float_list.value: - values = feature.float_list.value - tf.compat.v1.logging.info( - "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) - - for writer in writers: - writer.close() - - tf.compat.v1.logging.info("Wrote %d total instances", total_written) + """A single training instance (sentence pair).""" + + def __init__( + self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next + ): + self.tokens = tokens + self.segment_ids = segment_ids + self.is_random_next = is_random_next + self.masked_lm_positions = masked_lm_positions + self.masked_lm_labels = masked_lm_labels + + def __str__(self): + s = "" + s += "tokens: %s\n" % ( + " ".join([tokenization.printable_text(x) for x in self.tokens]) + ) + s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) + s += "is_random_next: %s\n" % self.is_random_next + s += "masked_lm_positions: %s\n" % ( + " ".join([str(x) for x in self.masked_lm_positions]) + ) + s += "masked_lm_labels: %s\n" % ( + " ".join([tokenization.printable_text(x) for x in self.masked_lm_labels]) + ) + s += "\n" + return s + + def __repr__(self): + return self.__str__() + + +def write_instance_to_example_files( + instances, tokenizer, max_seq_length, max_predictions_per_seq, output_files +): + """Create TF example files from `TrainingInstance`s.""" + writers = [] + for output_file in output_files: + writers.append(tf.io.TFRecordWriter(output_file)) + + writer_index = 0 + + total_written = 0 + for inst_index, instance in enumerate(instances): + input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) + input_mask = [1] * len(input_ids) + segment_ids = list(instance.segment_ids) + assert len(input_ids) <= max_seq_length + + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + masked_lm_positions = list(instance.masked_lm_positions) + masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) + masked_lm_weights = [1.0] * len(masked_lm_ids) + + while len(masked_lm_positions) < max_predictions_per_seq: + masked_lm_positions.append(0) + masked_lm_ids.append(0) + masked_lm_weights.append(0.0) + + next_sentence_label = 1 if instance.is_random_next else 0 + + features = collections.OrderedDict() + features["input_ids"] = create_int_feature(input_ids) + features["input_mask"] = create_int_feature(input_mask) + features["segment_ids"] = create_int_feature(segment_ids) + features["masked_lm_positions"] = create_int_feature(masked_lm_positions) + features["masked_lm_ids"] = create_int_feature(masked_lm_ids) + features["masked_lm_weights"] = create_float_feature(masked_lm_weights) + features["next_sentence_labels"] = create_int_feature([next_sentence_label]) + + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + + writers[writer_index].write(tf_example.SerializeToString()) + writer_index = (writer_index + 1) % len(writers) + + total_written += 1 + + if inst_index < 20: + tf.compat.v1.logging.info("*** Example ***") + tf.compat.v1.logging.info( + "tokens: %s" + % " ".join([tokenization.printable_text(x) for x in instance.tokens]) + ) + + for feature_name in features.keys(): + feature = features[feature_name] + values = [] + if feature.int64_list.value: + values = feature.int64_list.value + elif feature.float_list.value: + values = feature.float_list.value + tf.compat.v1.logging.info( + "%s: %s" % (feature_name, " ".join([str(x) for x in values])) + ) + + for writer in writers: + writer.close() + + tf.compat.v1.logging.info("Wrote %d total instances", total_written) def create_int_feature(values): - feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) - return feature + feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) + return feature def create_float_feature(values): - feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) - return feature - - -def create_training_instances(input_files, tokenizer, max_seq_length, - dupe_factor, short_seq_prob, masked_lm_prob, - max_predictions_per_seq, rng): - """Create `TrainingInstance`s from raw text.""" - all_documents = [[]] - - # Input file format: - # (1) One sentence per line. These should ideally be actual sentences, not - # entire paragraphs or arbitrary spans of text. (Because we use the - # sentence boundaries for the "next sentence prediction" task). - # (2) Blank lines between documents. Document boundaries are needed so - # that the "next sentence prediction" task doesn't span between documents. - for input_file in input_files: - with tf.io.gfile.GFile(input_file, "r") as reader: - while True: - line = tokenization.convert_to_unicode(reader.readline()) - if not line: - break - line = line.strip() - - # Empty lines are used as document delimiters - if not line: - all_documents.append([]) - tokens = tokenizer.tokenize(line) - if tokens: - all_documents[-1].append(tokens) - - # Remove empty documents - all_documents = [x for x in all_documents if x] - rng.shuffle(all_documents) - - vocab_words = list(tokenizer.vocab.keys()) - instances = [] - for _ in range(dupe_factor): - for document_index in range(len(all_documents)): - instances.extend( - create_instances_from_document( - all_documents, document_index, max_seq_length, short_seq_prob, - masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) - - rng.shuffle(instances) - return instances + feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) + return feature + + +def create_training_instances( + input_files, + tokenizer, + max_seq_length, + dupe_factor, + short_seq_prob, + masked_lm_prob, + max_predictions_per_seq, + rng, +): + """Create `TrainingInstance`s from raw text.""" + all_documents = [[]] + + # Input file format: + # (1) One sentence per line. These should ideally be actual sentences, not + # entire paragraphs or arbitrary spans of text. (Because we use the + # sentence boundaries for the "next sentence prediction" task). + # (2) Blank lines between documents. Document boundaries are needed so + # that the "next sentence prediction" task doesn't span between documents. + for input_file in input_files: + with tf.io.gfile.GFile(input_file, "r") as reader: + while True: + line = tokenization.convert_to_unicode(reader.readline()) + if not line: + break + line = line.strip() + + # Empty lines are used as document delimiters + if not line: + all_documents.append([]) + tokens = tokenizer.tokenize(line) + if tokens: + all_documents[-1].append(tokens) + + # Remove empty documents + all_documents = [x for x in all_documents if x] + rng.shuffle(all_documents) + + vocab_words = list(tokenizer.vocab.keys()) + instances = [] + for _ in range(dupe_factor): + for document_index in range(len(all_documents)): + instances.extend( + create_instances_from_document( + all_documents, + document_index, + max_seq_length, + short_seq_prob, + masked_lm_prob, + max_predictions_per_seq, + vocab_words, + rng, + ) + ) + + rng.shuffle(instances) + return instances def create_instances_from_document( - all_documents, document_index, max_seq_length, short_seq_prob, - masked_lm_prob, max_predictions_per_seq, vocab_words, rng): - """Creates `TrainingInstance`s for a single document.""" - document = all_documents[document_index] - - # Account for [CLS], [SEP], [SEP] - max_num_tokens = max_seq_length - 3 - - # We *usually* want to fill up the entire sequence since we are padding - # to `max_seq_length` anyways, so short sequences are generally wasted - # computation. However, we *sometimes* - # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter - # sequences to minimize the mismatch between pre-training and fine-tuning. - # The `target_seq_length` is just a rough target however, whereas - # `max_seq_length` is a hard limit. - target_seq_length = max_num_tokens - if rng.random() < short_seq_prob: - target_seq_length = rng.randint(2, max_num_tokens) - - # We DON'T just concatenate all of the tokens from a document into a long - # sequence and choose an arbitrary split point because this would make the - # next sentence prediction task too easy. Instead, we split the input into - # segments "A" and "B" based on the actual "sentences" provided by the user - # input. - instances = [] - current_chunk = [] - current_length = 0 - i = 0 - while i < len(document): - segment = document[i] - current_chunk.append(segment) - current_length += len(segment) - if i == len(document) - 1 or current_length >= target_seq_length: - if current_chunk: - # `a_end` is how many segments from `current_chunk` go into the `A` - # (first) sentence. - a_end = 1 - if len(current_chunk) >= 2: - a_end = rng.randint(1, len(current_chunk) - 1) - - tokens_a = [] - for j in range(a_end): - tokens_a.extend(current_chunk[j]) - - tokens_b = [] - # Random next - is_random_next = False - if len(current_chunk) == 1 or rng.random() < 0.5: - is_random_next = True - target_b_length = target_seq_length - len(tokens_a) - - # This should rarely go for more than one iteration for large - # corpora. However, just to be careful, we try to make sure that - # the random document is not the same as the document - # we're processing. - for _ in range(10): - random_document_index = rng.randint(0, len(all_documents) - 1) - if random_document_index != document_index: - break - - random_document = all_documents[random_document_index] - random_start = rng.randint(0, len(random_document) - 1) - for j in range(random_start, len(random_document)): - tokens_b.extend(random_document[j]) - if len(tokens_b) >= target_b_length: - break - # We didn't actually use these segments so we "put them back" so - # they don't go to waste. - num_unused_segments = len(current_chunk) - a_end - i -= num_unused_segments - # Actual next - else: - is_random_next = False - for j in range(a_end, len(current_chunk)): - tokens_b.extend(current_chunk[j]) - truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) - - assert len(tokens_a) >= 1 - assert len(tokens_b) >= 1 - - tokens = [] - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in tokens_a: - tokens.append(token) - segment_ids.append(0) - - tokens.append("[SEP]") - segment_ids.append(0) - - for token in tokens_b: - tokens.append(token) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) - - (tokens, masked_lm_positions, - masked_lm_labels) = create_masked_lm_predictions( - tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) - instance = TrainingInstance( - tokens=tokens, - segment_ids=segment_ids, - is_random_next=is_random_next, - masked_lm_positions=masked_lm_positions, - masked_lm_labels=masked_lm_labels) - instances.append(instance) - current_chunk = [] - current_length = 0 - i += 1 - - return instances - - -MaskedLmInstance = collections.namedtuple("MaskedLmInstance", - ["index", "label"]) - - -def create_masked_lm_predictions(tokens, masked_lm_prob, - max_predictions_per_seq, vocab_words, rng): - """Creates the predictions for the masked LM objective.""" - - cand_indexes = [] - for (i, token) in enumerate(tokens): - if token == "[CLS]" or token == "[SEP]": - continue - # Whole Word Masking means that if we mask all of the wordpieces - # corresponding to an original word. When a word has been split into - # WordPieces, the first token does not have any marker and any subsequence - # tokens are prefixed with ##. So whenever we see the ## token, we - # append it to the previous set of word indexes. - # - # Note that Whole Word Masking does *not* change the training code - # at all -- we still predict each WordPiece independently, softmaxed - # over the entire vocabulary. - if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and - token.startswith("##")): - cand_indexes[-1].append(i) - else: - cand_indexes.append([i]) - - rng.shuffle(cand_indexes) - - output_tokens = list(tokens) - - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) - - masked_lms = [] - covered_indexes = set() - for index_set in cand_indexes: - if len(masked_lms) >= num_to_predict: - break - # If adding a whole-word mask would exceed the maximum number of - # predictions, then just skip this candidate. - if len(masked_lms) + len(index_set) > num_to_predict: - continue - is_any_index_covered = False - for index in index_set: - if index in covered_indexes: - is_any_index_covered = True - break - if is_any_index_covered: - continue - for index in index_set: - covered_indexes.add(index) - - masked_token = None - # 80% of the time, replace with [MASK] - if rng.random() < 0.8: - masked_token = "[MASK]" - else: - # 10% of the time, keep original - if rng.random() < 0.5: - masked_token = tokens[index] - # 10% of the time, replace with random word + all_documents, + document_index, + max_seq_length, + short_seq_prob, + masked_lm_prob, + max_predictions_per_seq, + vocab_words, + rng, +): + """Creates `TrainingInstance`s for a single document.""" + document = all_documents[document_index] + + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + + # We *usually* want to fill up the entire sequence since we are padding + # to `max_seq_length` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `max_seq_length` is a hard limit. + target_seq_length = max_num_tokens + if rng.random() < short_seq_prob: + target_seq_length = rng.randint(2, max_num_tokens) + + # We DON'T just concatenate all of the tokens from a document into a long + # sequence and choose an arbitrary split point because this would make the + # next sentence prediction task too easy. Instead, we split the input into + # segments "A" and "B" based on the actual "sentences" provided by the user + # input. + instances = [] + current_chunk = [] + current_length = 0 + i = 0 + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into the `A` + # (first) sentence. + a_end = 1 + if len(current_chunk) >= 2: + a_end = rng.randint(1, len(current_chunk) - 1) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + # Random next + is_random_next = False + if len(current_chunk) == 1 or rng.random() < 0.5: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + + # This should rarely go for more than one iteration for large + # corpora. However, just to be careful, we try to make sure that + # the random document is not the same as the document + # we're processing. + for _ in range(10): + random_document_index = rng.randint(0, len(all_documents) - 1) + if random_document_index != document_index: + break + + random_document = all_documents[random_document_index] + random_start = rng.randint(0, len(random_document) - 1) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we "put them back" so + # they don't go to waste. + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_random_next = False + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens = [] + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + + tokens.append("[SEP]") + segment_ids.append(0) + + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + tokens, masked_lm_positions, masked_lm_labels = ( + create_masked_lm_predictions( + tokens, + masked_lm_prob, + max_predictions_per_seq, + vocab_words, + rng, + ) + ) + instance = TrainingInstance( + tokens=tokens, + segment_ids=segment_ids, + is_random_next=is_random_next, + masked_lm_positions=masked_lm_positions, + masked_lm_labels=masked_lm_labels, + ) + instances.append(instance) + current_chunk = [] + current_length = 0 + i += 1 + + return instances + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) + + +def create_masked_lm_predictions( + tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng +): + """Creates the predictions for the masked LM objective.""" + + cand_indexes = [] + for i, token in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if ( + FLAGS.do_whole_word_mask + and len(cand_indexes) >= 1 + and token.startswith("##") + ): + cand_indexes[-1].append(i) else: - masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] - - output_tokens[index] = masked_token - - masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) - assert len(masked_lms) <= num_to_predict - masked_lms = sorted(masked_lms, key=lambda x: x.index) - - masked_lm_positions = [] - masked_lm_labels = [] - for p in masked_lms: - masked_lm_positions.append(p.index) - masked_lm_labels.append(p.label) - - return (output_tokens, masked_lm_positions, masked_lm_labels) + cand_indexes.append([i]) + + rng.shuffle(cand_indexes) + + output_tokens = list(tokens) + + num_to_predict = min( + max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))) + ) + + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if rng.random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] + + output_tokens[index] = masked_token + + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + assert len(masked_lms) <= num_to_predict + masked_lms = sorted(masked_lms, key=lambda x: x.index) + + masked_lm_positions = [] + masked_lm_labels = [] + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + + return (output_tokens, masked_lm_positions, masked_lm_labels) def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): - """Truncates a pair of sequences to a maximum sequence length.""" - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_num_tokens: - break + """Truncates a pair of sequences to a maximum sequence length.""" + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break - trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b - assert len(trunc_tokens) >= 1 + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 - # We want to sometimes truncate from the front and sometimes from the - # back to add more randomness and avoid biases. - if rng.random() < 0.5: - del trunc_tokens[0] - else: - trunc_tokens.pop() + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if rng.random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() def main(_): - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) - - tokenizer = tokenization.FullTokenizer( - vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) - - input_files = [] - for input_pattern in FLAGS.input_file.split(","): - input_files.extend(tf.io.gfile.glob(input_pattern)) - - tf.compat.v1.logging.info("*** Reading from input files ***") - for input_file in input_files: - tf.compat.v1.logging.info(" %s", input_file) - - rng = random.Random(FLAGS.random_seed) - instances = create_training_instances( - input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, - FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, - rng) - - output_files = FLAGS.output_file.split(",") - tf.compat.v1.logging.info("*** Writing to output files ***") - for output_file in output_files: - tf.compat.v1.logging.info(" %s", output_file) - - write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, - FLAGS.max_predictions_per_seq, output_files) + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + tokenizer = tokenization.FullTokenizer( + vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case + ) + + input_files = [] + for input_pattern in FLAGS.input_file.split(","): + input_files.extend(tf.io.gfile.glob(input_pattern)) + + tf.compat.v1.logging.info("*** Reading from input files ***") + for input_file in input_files: + tf.compat.v1.logging.info(" %s", input_file) + + rng = random.Random(FLAGS.random_seed) + instances = create_training_instances( + input_files, + tokenizer, + FLAGS.max_seq_length, + FLAGS.dupe_factor, + FLAGS.short_seq_prob, + FLAGS.masked_lm_prob, + FLAGS.max_predictions_per_seq, + rng, + ) + + output_files = FLAGS.output_file.split(",") + tf.compat.v1.logging.info("*** Writing to output files ***") + for output_file in output_files: + tf.compat.v1.logging.info(" %s", output_file) + + write_instance_to_example_files( + instances, + tokenizer, + FLAGS.max_seq_length, + FLAGS.max_predictions_per_seq, + output_files, + ) if __name__ == "__main__": - flags.mark_flag_as_required("input_file") - flags.mark_flag_as_required("output_file") - flags.mark_flag_as_required("vocab_file") - tf.compat.v1.app.run() + flags.mark_flag_as_required("input_file") + flags.mark_flag_as_required("output_file") + flags.mark_flag_as_required("vocab_file") + tf.compat.v1.app.run() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/extract_features.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/extract_features.py index 3d236874..8e6d294b 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/extract_features.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/extract_features.py @@ -38,382 +38,409 @@ flags.DEFINE_string("layers", "-1,-2,-3,-4", "") flags.DEFINE_string( - "bert_config_file", None, + "bert_config_file", + None, "The config json file corresponding to the pre-trained BERT model. " - "This specifies the model architecture.") + "This specifies the model architecture.", +) flags.DEFINE_integer( - "max_seq_length", 128, + "max_seq_length", + 128, "The maximum total input sequence length after WordPiece tokenization. " "Sequences longer than this will be truncated, and sequences shorter " - "than this will be padded.") + "than this will be padded.", +) flags.DEFINE_string( - "init_checkpoint", None, - "Initial checkpoint (usually from a pre-trained BERT model).") + "init_checkpoint", + None, + "Initial checkpoint (usually from a pre-trained BERT model).", +) -flags.DEFINE_string("vocab_file", None, - "The vocabulary file that the BERT model was trained on.") +flags.DEFINE_string( + "vocab_file", None, "The vocabulary file that the BERT model was trained on." +) flags.DEFINE_bool( - "do_lower_case", True, + "do_lower_case", + True, "Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.") + "models and False for cased models.", +) flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") -flags.DEFINE_string("master", None, - "If using a TPU, the address of the master.") +flags.DEFINE_string("master", None, "If using a TPU, the address of the master.") flags.DEFINE_integer( - "num_tpu_cores", 8, - "Only used if `use_tpu` is True. Total number of TPU cores to use.") + "num_tpu_cores", + 8, + "Only used if `use_tpu` is True. Total number of TPU cores to use.", +) flags.DEFINE_bool( - "use_one_hot_embeddings", False, + "use_one_hot_embeddings", + False, "If True, tf.one_hot will be used for embedding lookups, otherwise " "tf.nn.embedding_lookup will be used. On TPUs, this should be True " - "since it is much faster.") + "since it is much faster.", +) class InputExample(object): - def __init__(self, unique_id, text_a, text_b): - self.unique_id = unique_id - self.text_a = text_a - self.text_b = text_b + def __init__(self, unique_id, text_a, text_b): + self.unique_id = unique_id + self.text_a = text_a + self.text_b = text_b class InputFeatures(object): - """A single set of features of data.""" + """A single set of features of data.""" - def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): - self.unique_id = unique_id - self.tokens = tokens - self.input_ids = input_ids - self.input_mask = input_mask - self.input_type_ids = input_type_ids + def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): + self.unique_id = unique_id + self.tokens = tokens + self.input_ids = input_ids + self.input_mask = input_mask + self.input_type_ids = input_type_ids def input_fn_builder(features, seq_length): - """Creates an `input_fn` closure to be passed to TPUEstimator.""" - - all_unique_ids = [] - all_input_ids = [] - all_input_mask = [] - all_input_type_ids = [] - - for feature in features: - all_unique_ids.append(feature.unique_id) - all_input_ids.append(feature.input_ids) - all_input_mask.append(feature.input_mask) - all_input_type_ids.append(feature.input_type_ids) - - def input_fn(params): - """The actual input function.""" - batch_size = params["batch_size"] - - num_examples = len(features) - - # This is for demo purposes and does NOT scale to large data sets. We do - # not use Dataset.from_generator() because that uses tf.py_func which is - # not TPU compatible. The right way to load data is with TFRecordReader. - d = tf.data.Dataset.from_tensor_slices({ - "unique_ids": - tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), - "input_ids": - tf.constant( - all_input_ids, shape=[num_examples, seq_length], - dtype=tf.int32), - "input_mask": - tf.constant( - all_input_mask, - shape=[num_examples, seq_length], - dtype=tf.int32), - "input_type_ids": - tf.constant( - all_input_type_ids, - shape=[num_examples, seq_length], - dtype=tf.int32), - }) - - d = d.batch(batch_size=batch_size, drop_remainder=False) - return d - - return input_fn - - -def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, - use_one_hot_embeddings): - """Returns `model_fn` closure for TPUEstimator.""" - - def model_fn(features, labels, mode, params): # pylint: disable=unused-argument - """The `model_fn` for TPUEstimator.""" - - unique_ids = features["unique_ids"] - input_ids = features["input_ids"] - input_mask = features["input_mask"] - input_type_ids = features["input_type_ids"] - - model = modeling.BertModel( - config=bert_config, - is_training=False, - input_ids=input_ids, - input_mask=input_mask, - token_type_ids=input_type_ids, - use_one_hot_embeddings=use_one_hot_embeddings) - - if mode != tf.estimator.ModeKeys.PREDICT: - raise ValueError("Only PREDICT modes are supported: %s" % (mode)) - - tvars = tf.compat.v1.trainable_variables() - scaffold_fn = None - (assignment_map, - initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( - tvars, init_checkpoint) - if use_tpu: - - def tpu_scaffold(): - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - return tf.compat.v1.train.Scaffold() - - scaffold_fn = tpu_scaffold - else: - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - - tf.compat.v1.logging.info("**** Trainable Variables ****") - for var in tvars: - init_string = "" - if var.name in initialized_variable_names: - init_string = ", *INIT_FROM_CKPT*" - tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, - init_string) - - all_layers = model.get_all_encoder_layers() - - predictions = { - "unique_id": unique_ids, - } - - for (i, layer_index) in enumerate(layer_indexes): - predictions["layer_output_%d" % i] = all_layers[layer_index] - - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) - return output_spec - - return model_fn + """Creates an `input_fn` closure to be passed to TPUEstimator.""" + + all_unique_ids = [] + all_input_ids = [] + all_input_mask = [] + all_input_type_ids = [] + + for feature in features: + all_unique_ids.append(feature.unique_id) + all_input_ids.append(feature.input_ids) + all_input_mask.append(feature.input_mask) + all_input_type_ids.append(feature.input_type_ids) + + def input_fn(params): + """The actual input function.""" + batch_size = params["batch_size"] + + num_examples = len(features) + + # This is for demo purposes and does NOT scale to large data sets. We do + # not use Dataset.from_generator() because that uses tf.py_func which is + # not TPU compatible. The right way to load data is with TFRecordReader. + d = tf.data.Dataset.from_tensor_slices( + { + "unique_ids": tf.constant( + all_unique_ids, shape=[num_examples], dtype=tf.int32 + ), + "input_ids": tf.constant( + all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32 + ), + "input_mask": tf.constant( + all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32 + ), + "input_type_ids": tf.constant( + all_input_type_ids, shape=[num_examples, seq_length], dtype=tf.int32 + ), + } + ) + + d = d.batch(batch_size=batch_size, drop_remainder=False) + return d + + return input_fn + + +def model_fn_builder( + bert_config, init_checkpoint, layer_indexes, use_tpu, use_one_hot_embeddings +): + """Returns `model_fn` closure for TPUEstimator.""" + + def model_fn(features, labels, mode, params): # pylint: disable=unused-argument + """The `model_fn` for TPUEstimator.""" + + unique_ids = features["unique_ids"] + input_ids = features["input_ids"] + input_mask = features["input_mask"] + input_type_ids = features["input_type_ids"] + + model = modeling.BertModel( + config=bert_config, + is_training=False, + input_ids=input_ids, + input_mask=input_mask, + token_type_ids=input_type_ids, + use_one_hot_embeddings=use_one_hot_embeddings, + ) + + if mode != tf.estimator.ModeKeys.PREDICT: + raise ValueError("Only PREDICT modes are supported: %s" % (mode)) + + tvars = tf.compat.v1.trainable_variables() + scaffold_fn = None + assignment_map, initialized_variable_names = ( + modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) + ) + if use_tpu: + + def tpu_scaffold(): + tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) + return tf.compat.v1.train.Scaffold() + + scaffold_fn = tpu_scaffold + else: + tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) + + tf.compat.v1.logging.info("**** Trainable Variables ****") + for var in tvars: + init_string = "" + if var.name in initialized_variable_names: + init_string = ", *INIT_FROM_CKPT*" + tf.compat.v1.logging.info( + " name = %s, shape = %s%s", var.name, var.shape, init_string + ) + + all_layers = model.get_all_encoder_layers() + + predictions = { + "unique_id": unique_ids, + } + + for i, layer_index in enumerate(layer_indexes): + predictions["layer_output_%d" % i] = all_layers[layer_index] + + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, predictions=predictions, scaffold_fn=scaffold_fn + ) + return output_spec + + return model_fn def convert_examples_to_features(examples, seq_length, tokenizer): - """Loads a data file into a list of `InputBatch`s.""" - - features = [] - for (ex_index, example) in enumerate(examples): - tokens_a = tokenizer.tokenize(example.text_a) - - tokens_b = None - if example.text_b: - tokens_b = tokenizer.tokenize(example.text_b) - - if tokens_b: - # Modifies `tokens_a` and `tokens_b` in place so that the total - # length is less than the specified length. - # Account for [CLS], [SEP], [SEP] with "- 3" - _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) - else: - # Account for [CLS] and [SEP] with "- 2" - if len(tokens_a) > seq_length - 2: - tokens_a = tokens_a[0:(seq_length - 2)] - - # The convention in BERT is: - # (a) For sequence pairs: - # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] - # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 - # (b) For single sequences: - # tokens: [CLS] the dog is hairy . [SEP] - # type_ids: 0 0 0 0 0 0 0 - # - # Where "type_ids" are used to indicate whether this is the first - # sequence or the second sequence. The embedding vectors for `type=0` and - # `type=1` were learned during pre-training and are added to the wordpiece - # embedding vector (and position vector). This is not *strictly* necessary - # since the [SEP] token unambiguously separates the sequences, but it makes - # it easier for the model to learn the concept of sequences. - # - # For classification tasks, the first vector (corresponding to [CLS]) is - # used as as the "sentence vector". Note that this only makes sense because - # the entire model is fine-tuned. - tokens = [] - input_type_ids = [] - tokens.append("[CLS]") - input_type_ids.append(0) - for token in tokens_a: - tokens.append(token) - input_type_ids.append(0) - tokens.append("[SEP]") - input_type_ids.append(0) - - if tokens_b: - for token in tokens_b: - tokens.append(token) - input_type_ids.append(1) - tokens.append("[SEP]") - input_type_ids.append(1) - - input_ids = tokenizer.convert_tokens_to_ids(tokens) - - # The mask has 1 for real tokens and 0 for padding tokens. Only real - # tokens are attended to. - input_mask = [1] * len(input_ids) - - # Zero-pad up to the sequence length. - while len(input_ids) < seq_length: - input_ids.append(0) - input_mask.append(0) - input_type_ids.append(0) - - assert len(input_ids) == seq_length - assert len(input_mask) == seq_length - assert len(input_type_ids) == seq_length - - if ex_index < 5: - tf.compat.v1.logging.info("*** Example ***") - tf.compat.v1.logging.info("unique_id: %s" % (example.unique_id)) - tf.compat.v1.logging.info("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in tokens])) - tf.compat.v1.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) - tf.compat.v1.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) - tf.compat.v1.logging.info( - "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) - - features.append( - InputFeatures( - unique_id=example.unique_id, - tokens=tokens, - input_ids=input_ids, - input_mask=input_mask, - input_type_ids=input_type_ids)) - return features + """Loads a data file into a list of `InputBatch`s.""" + + features = [] + for ex_index, example in enumerate(examples): + tokens_a = tokenizer.tokenize(example.text_a) + + tokens_b = None + if example.text_b: + tokens_b = tokenizer.tokenize(example.text_b) + + if tokens_b: + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP] with "- 3" + _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) + else: + # Account for [CLS] and [SEP] with "- 2" + if len(tokens_a) > seq_length - 2: + tokens_a = tokens_a[0 : (seq_length - 2)] + + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + input_type_ids = [] + tokens.append("[CLS]") + input_type_ids.append(0) + for token in tokens_a: + tokens.append(token) + input_type_ids.append(0) + tokens.append("[SEP]") + input_type_ids.append(0) + + if tokens_b: + for token in tokens_b: + tokens.append(token) + input_type_ids.append(1) + tokens.append("[SEP]") + input_type_ids.append(1) + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + + # Zero-pad up to the sequence length. + while len(input_ids) < seq_length: + input_ids.append(0) + input_mask.append(0) + input_type_ids.append(0) + + assert len(input_ids) == seq_length + assert len(input_mask) == seq_length + assert len(input_type_ids) == seq_length + + if ex_index < 5: + tf.compat.v1.logging.info("*** Example ***") + tf.compat.v1.logging.info("unique_id: %s" % (example.unique_id)) + tf.compat.v1.logging.info( + "tokens: %s" + % " ".join([tokenization.printable_text(x) for x in tokens]) + ) + tf.compat.v1.logging.info( + "input_ids: %s" % " ".join([str(x) for x in input_ids]) + ) + tf.compat.v1.logging.info( + "input_mask: %s" % " ".join([str(x) for x in input_mask]) + ) + tf.compat.v1.logging.info( + "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]) + ) + + features.append( + InputFeatures( + unique_id=example.unique_id, + tokens=tokens, + input_ids=input_ids, + input_mask=input_mask, + input_type_ids=input_type_ids, + ) + ) + return features def _truncate_seq_pair(tokens_a, tokens_b, max_length): - """Truncates a sequence pair in place to the maximum length.""" - - # This is a simple heuristic which will always truncate the longer sequence - # one token at a time. This makes more sense than truncating an equal percent - # of tokens from each, since if one sequence is very short then each token - # that's truncated likely contains more information than a longer sequence. - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_length: - break - if len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() + """Truncates a sequence pair in place to the maximum length.""" + + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() def read_examples(input_file): - """Read a list of `InputExample`s from an input file.""" - examples = [] - unique_id = 0 - with tf.io.gfile.GFile(input_file, "r") as reader: - while True: - line = tokenization.convert_to_unicode(reader.readline()) - if not line: - break - line = line.strip() - text_a = None - text_b = None - m = re.match(r"^(.*) \|\|\| (.*)$", line) - if m is None: - text_a = line - else: - text_a = m.group(1) - text_b = m.group(2) - examples.append( - InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) - unique_id += 1 - return examples + """Read a list of `InputExample`s from an input file.""" + examples = [] + unique_id = 0 + with tf.io.gfile.GFile(input_file, "r") as reader: + while True: + line = tokenization.convert_to_unicode(reader.readline()) + if not line: + break + line = line.strip() + text_a = None + text_b = None + m = re.match(r"^(.*) \|\|\| (.*)$", line) + if m is None: + text_a = line + else: + text_a = m.group(1) + text_b = m.group(2) + examples.append( + InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b) + ) + unique_id += 1 + return examples def main(_): - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) - - layer_indexes = [int(x) for x in FLAGS.layers.split(",")] - - bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) - - tokenizer = tokenization.FullTokenizer( - vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) - - is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.compat.v1.estimator.tpu.RunConfig( - master=FLAGS.master, - tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - - examples = read_examples(FLAGS.input_file) - - features = convert_examples_to_features( - examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) - - unique_id_to_feature = {} - for feature in features: - unique_id_to_feature[feature.unique_id] = feature - - model_fn = model_fn_builder( - bert_config=bert_config, - init_checkpoint=FLAGS.init_checkpoint, - layer_indexes=layer_indexes, - use_tpu=FLAGS.use_tpu, - use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) - - # If TPU is not available, this will fall back to normal Estimator on CPU - # or GPU. - estimator = tf.compat.v1.estimator.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, - model_fn=model_fn, - config=run_config, - predict_batch_size=FLAGS.batch_size) - - input_fn = input_fn_builder( - features=features, seq_length=FLAGS.max_seq_length) - - with codecs.getwriter("utf-8")(tf.io.gfile.GFile(FLAGS.output_file, - "w")) as writer: - for result in estimator.predict(input_fn, yield_single_examples=True): - unique_id = int(result["unique_id"]) - feature = unique_id_to_feature[unique_id] - output_json = collections.OrderedDict() - output_json["linex_index"] = unique_id - all_features = [] - for (i, token) in enumerate(feature.tokens): - all_layers = [] - for (j, layer_index) in enumerate(layer_indexes): - layer_output = result["layer_output_%d" % j] - layers = collections.OrderedDict() - layers["index"] = layer_index - layers["values"] = [ - round(float(x), 6) for x in layer_output[i:(i + 1)].flat - ] - all_layers.append(layers) - features = collections.OrderedDict() - features["token"] = token - features["layers"] = all_layers - all_features.append(features) - output_json["features"] = all_features - writer.write(json.dumps(output_json) + "\n") + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + layer_indexes = [int(x) for x in FLAGS.layers.split(",")] + + bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) + + tokenizer = tokenization.FullTokenizer( + vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case + ) + + is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 + run_config = tf.compat.v1.estimator.tpu.RunConfig( + master=FLAGS.master, + tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( + num_shards=FLAGS.num_tpu_cores, per_host_input_for_training=is_per_host + ), + ) + + examples = read_examples(FLAGS.input_file) + + features = convert_examples_to_features( + examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer + ) + + unique_id_to_feature = {} + for feature in features: + unique_id_to_feature[feature.unique_id] = feature + + model_fn = model_fn_builder( + bert_config=bert_config, + init_checkpoint=FLAGS.init_checkpoint, + layer_indexes=layer_indexes, + use_tpu=FLAGS.use_tpu, + use_one_hot_embeddings=FLAGS.use_one_hot_embeddings, + ) + + # If TPU is not available, this will fall back to normal Estimator on CPU + # or GPU. + estimator = tf.compat.v1.estimator.tpu.TPUEstimator( + use_tpu=FLAGS.use_tpu, + model_fn=model_fn, + config=run_config, + predict_batch_size=FLAGS.batch_size, + ) + + input_fn = input_fn_builder(features=features, seq_length=FLAGS.max_seq_length) + + with codecs.getwriter("utf-8")(tf.io.gfile.GFile(FLAGS.output_file, "w")) as writer: + for result in estimator.predict(input_fn, yield_single_examples=True): + unique_id = int(result["unique_id"]) + feature = unique_id_to_feature[unique_id] + output_json = collections.OrderedDict() + output_json["linex_index"] = unique_id + all_features = [] + for i, token in enumerate(feature.tokens): + all_layers = [] + for j, layer_index in enumerate(layer_indexes): + layer_output = result["layer_output_%d" % j] + layers = collections.OrderedDict() + layers["index"] = layer_index + layers["values"] = [ + round(float(x), 6) for x in layer_output[i : (i + 1)].flat + ] + all_layers.append(layers) + features = collections.OrderedDict() + features["token"] = token + features["layers"] = all_layers + all_features.append(features) + output_json["features"] = all_features + writer.write(json.dumps(output_json) + "\n") if __name__ == "__main__": - flags.mark_flag_as_required("input_file") - flags.mark_flag_as_required("vocab_file") - flags.mark_flag_as_required("bert_config_file") - flags.mark_flag_as_required("init_checkpoint") - flags.mark_flag_as_required("output_file") - tf.compat.v1.app.run() + flags.mark_flag_as_required("input_file") + flags.mark_flag_as_required("vocab_file") + flags.mark_flag_as_required("bert_config_file") + flags.mark_flag_as_required("init_checkpoint") + flags.mark_flag_as_required("output_file") + tf.compat.v1.app.run() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling.py index 05d7e87e..f7dc6079 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling.py @@ -29,958 +29,999 @@ class BertConfig(object): - """Configuration for `BertModel`.""" - - def __init__(self, - vocab_size, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - initializer_range=0.02): - """Constructs BertConfig. - - Args: - vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. - hidden_dropout_prob: The dropout probability for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The stdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - - @classmethod - def from_dict(cls, json_object): - """Constructs a `BertConfig` from a Python dictionary of parameters.""" - config = BertConfig(vocab_size=None) - for (key, value) in six.iteritems(json_object): - config.__dict__[key] = value - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" - with tf.io.gfile.GFile(json_file, "r") as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + """Configuration for `BertModel`.""" + + def __init__( + self, + vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + ): + """Constructs BertConfig. + + Args: + vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. + hidden_dropout_prob: The dropout probability for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The stdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size=None) + for key, value in six.iteritems(json_object): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with tf.io.gfile.GFile(json_file, "r") as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" class BertModel(object): - """BERT model ("Bidirectional Encoder Representations from Transformers"). - - Example usage: - - ```python - # Already been converted into WordPiece token ids - input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) - input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) - token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) - - config = modeling.BertConfig(vocab_size=32000, hidden_size=512, - num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) - - model = modeling.BertModel(config=config, is_training=True, - input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) - - label_embeddings = tf.get_variable(...) - pooled_output = model.get_pooled_output() - logits = tf.matmul(pooled_output, label_embeddings) - ... - ``` - """ - - def __init__(self, - config, - is_training, - input_ids, - input_mask=None, - token_type_ids=None, - use_one_hot_embeddings=False, - scope=None): - """Constructor for BertModel. + """BERT model ("Bidirectional Encoder Representations from Transformers"). - Args: - config: `BertConfig` instance. - is_training: bool. true for training model, false for eval model. Controls - whether dropout will be applied. - input_ids: int32 Tensor of shape [batch_size, seq_length]. - input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. - token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. - use_one_hot_embeddings: (optional) bool. Whether to use one-hot word - embeddings or tf.embedding_lookup() for the word embeddings. - scope: (optional) variable scope. Defaults to "bert". + Example usage: - Raises: - ValueError: The config is invalid or one of the input tensor shapes - is invalid. - """ - config = copy.deepcopy(config) - if not is_training: - config.hidden_dropout_prob = 0.0 - config.attention_probs_dropout_prob = 0.0 + ```python + # Already been converted into WordPiece token ids + input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) + input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) + token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) - input_shape = get_shape_list(input_ids, expected_rank=2) - batch_size = input_shape[0] - seq_length = input_shape[1] + config = modeling.BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) - if input_mask is None: - input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) - - if token_type_ids is None: - token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) - - with tf.compat.v1.variable_scope(scope, default_name="bert"): - with tf.compat.v1.variable_scope("embeddings"): - # Perform embedding lookup on the word ids. - (self.embedding_output, self.embedding_table) = embedding_lookup( - input_ids=input_ids, - vocab_size=config.vocab_size, - embedding_size=config.hidden_size, - initializer_range=config.initializer_range, - word_embedding_name="word_embeddings", - use_one_hot_embeddings=use_one_hot_embeddings) - - # Add positional embeddings and token type embeddings, then layer - # normalize and perform dropout. - self.embedding_output = embedding_postprocessor( - input_tensor=self.embedding_output, - use_token_type=True, - token_type_ids=token_type_ids, - token_type_vocab_size=config.type_vocab_size, - token_type_embedding_name="token_type_embeddings", - use_position_embeddings=True, - position_embedding_name="position_embeddings", - initializer_range=config.initializer_range, - max_position_embeddings=config.max_position_embeddings, - dropout_prob=config.hidden_dropout_prob) - - with tf.compat.v1.variable_scope("encoder"): - # This converts a 2D mask of shape [batch_size, seq_length] to a 3D - # mask of shape [batch_size, seq_length, seq_length] which is used - # for the attention scores. - attention_mask = create_attention_mask_from_input_mask( - input_ids, input_mask) - - # Run the stacked transformer. - # `sequence_output` shape = [batch_size, seq_length, hidden_size]. - self.all_encoder_layers = transformer_model( - input_tensor=self.embedding_output, - attention_mask=attention_mask, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - intermediate_size=config.intermediate_size, - intermediate_act_fn=get_activation(config.hidden_act), - hidden_dropout_prob=config.hidden_dropout_prob, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - initializer_range=config.initializer_range, - do_return_all_layers=True) - - self.sequence_output = self.all_encoder_layers[-1] - # The "pooler" converts the encoded sequence tensor of shape - # [batch_size, seq_length, hidden_size] to a tensor of shape - # [batch_size, hidden_size]. This is necessary for segment-level - # (or segment-pair-level) classification tasks where we need a fixed - # dimensional representation of the segment. - with tf.compat.v1.variable_scope("pooler"): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. We assume that this has been pre-trained - first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) - self.pooled_output = tf.compat.v1.layers.dense( - first_token_tensor, - config.hidden_size, - activation=tf.tanh, - kernel_initializer=create_initializer(config.initializer_range)) - - def get_pooled_output(self): - return self.pooled_output - - def get_sequence_output(self): - """Gets final hidden layer of encoder. + model = modeling.BertModel(config=config, is_training=True, + input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) - Returns: - float Tensor of shape [batch_size, seq_length, hidden_size] corresponding - to the final hidden of the transformer encoder. + label_embeddings = tf.get_variable(...) + pooled_output = model.get_pooled_output() + logits = tf.matmul(pooled_output, label_embeddings) + ... + ``` """ - return self.sequence_output - def get_all_encoder_layers(self): - return self.all_encoder_layers + def __init__( + self, + config, + is_training, + input_ids, + input_mask=None, + token_type_ids=None, + use_one_hot_embeddings=False, + scope=None, + ): + """Constructor for BertModel. + + Args: + config: `BertConfig` instance. + is_training: bool. true for training model, false for eval model. Controls + whether dropout will be applied. + input_ids: int32 Tensor of shape [batch_size, seq_length]. + input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. + token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. + use_one_hot_embeddings: (optional) bool. Whether to use one-hot word + embeddings or tf.embedding_lookup() for the word embeddings. + scope: (optional) variable scope. Defaults to "bert". + + Raises: + ValueError: The config is invalid or one of the input tensor shapes + is invalid. + """ + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + input_shape = get_shape_list(input_ids, expected_rank=2) + batch_size = input_shape[0] + seq_length = input_shape[1] + + if input_mask is None: + input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) + + if token_type_ids is None: + token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) + + with tf.compat.v1.variable_scope(scope, default_name="bert"): + with tf.compat.v1.variable_scope("embeddings"): + # Perform embedding lookup on the word ids. + self.embedding_output, self.embedding_table = embedding_lookup( + input_ids=input_ids, + vocab_size=config.vocab_size, + embedding_size=config.hidden_size, + initializer_range=config.initializer_range, + word_embedding_name="word_embeddings", + use_one_hot_embeddings=use_one_hot_embeddings, + ) + + # Add positional embeddings and token type embeddings, then layer + # normalize and perform dropout. + self.embedding_output = embedding_postprocessor( + input_tensor=self.embedding_output, + use_token_type=True, + token_type_ids=token_type_ids, + token_type_vocab_size=config.type_vocab_size, + token_type_embedding_name="token_type_embeddings", + use_position_embeddings=True, + position_embedding_name="position_embeddings", + initializer_range=config.initializer_range, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob, + ) + + with tf.compat.v1.variable_scope("encoder"): + # This converts a 2D mask of shape [batch_size, seq_length] to a 3D + # mask of shape [batch_size, seq_length, seq_length] which is used + # for the attention scores. + attention_mask = create_attention_mask_from_input_mask( + input_ids, input_mask + ) + + # Run the stacked transformer. + # `sequence_output` shape = [batch_size, seq_length, hidden_size]. + self.all_encoder_layers = transformer_model( + input_tensor=self.embedding_output, + attention_mask=attention_mask, + hidden_size=config.hidden_size, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + intermediate_act_fn=get_activation(config.hidden_act), + hidden_dropout_prob=config.hidden_dropout_prob, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + initializer_range=config.initializer_range, + do_return_all_layers=True, + ) + + self.sequence_output = self.all_encoder_layers[-1] + # The "pooler" converts the encoded sequence tensor of shape + # [batch_size, seq_length, hidden_size] to a tensor of shape + # [batch_size, hidden_size]. This is necessary for segment-level + # (or segment-pair-level) classification tasks where we need a fixed + # dimensional representation of the segment. + with tf.compat.v1.variable_scope("pooler"): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. We assume that this has been pre-trained + first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) + self.pooled_output = tf.compat.v1.layers.dense( + first_token_tensor, + config.hidden_size, + activation=tf.tanh, + kernel_initializer=create_initializer(config.initializer_range), + ) + + def get_pooled_output(self): + return self.pooled_output + + def get_sequence_output(self): + """Gets final hidden layer of encoder. + + Returns: + float Tensor of shape [batch_size, seq_length, hidden_size] corresponding + to the final hidden of the transformer encoder. + """ + return self.sequence_output + + def get_all_encoder_layers(self): + return self.all_encoder_layers + + def get_embedding_output(self): + """Gets output of the embedding lookup (i.e., input to the transformer). + + Returns: + float Tensor of shape [batch_size, seq_length, hidden_size] corresponding + to the output of the embedding layer, after summing the word + embeddings with the positional embeddings and the token type embeddings, + then performing layer normalization. This is the input to the transformer. + """ + return self.embedding_output + + def get_embedding_table(self): + return self.embedding_table + + +def gelu(x): + """Gaussian Error Linear Unit. - def get_embedding_output(self): - """Gets output of the embedding lookup (i.e., input to the transformer). + This is a smoother version of the RELU. + Original paper: https://arxiv.org/abs/1606.08415 + Args: + x: float Tensor to perform activation. Returns: - float Tensor of shape [batch_size, seq_length, hidden_size] corresponding - to the output of the embedding layer, after summing the word - embeddings with the positional embeddings and the token type embeddings, - then performing layer normalization. This is the input to the transformer. + `x` with the GELU activation applied. """ - return self.embedding_output + cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) + return x * cdf - def get_embedding_table(self): - return self.embedding_table +def get_activation(activation_string): + """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. -def gelu(x): - """Gaussian Error Linear Unit. + Args: + activation_string: String name of the activation function. - This is a smoother version of the RELU. - Original paper: https://arxiv.org/abs/1606.08415 - Args: - x: float Tensor to perform activation. + Returns: + A Python function corresponding to the activation function. If + `activation_string` is None, empty, or "linear", this will return None. + If `activation_string` is not a string, it will return `activation_string`. + + Raises: + ValueError: The `activation_string` does not correspond to a known + activation. + """ - Returns: - `x` with the GELU activation applied. - """ - cdf = 0.5 * (1.0 + tf.tanh( - (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf + # We assume that anything that"s not a string is already an activation + # function, so we just return it. + if not isinstance(activation_string, six.string_types): + return activation_string + if not activation_string: + return None -def get_activation(activation_string): - """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. - - Args: - activation_string: String name of the activation function. - - Returns: - A Python function corresponding to the activation function. If - `activation_string` is None, empty, or "linear", this will return None. - If `activation_string` is not a string, it will return `activation_string`. - - Raises: - ValueError: The `activation_string` does not correspond to a known - activation. - """ - - # We assume that anything that"s not a string is already an activation - # function, so we just return it. - if not isinstance(activation_string, six.string_types): - return activation_string - - if not activation_string: - return None - - act = activation_string.lower() - if act == "linear": - return None - elif act == "relu": - return tf.nn.relu - elif act == "gelu": - return gelu - elif act == "tanh": - return tf.tanh - else: - raise ValueError("Unsupported activation: %s" % act) + act = activation_string.lower() + if act == "linear": + return None + elif act == "relu": + return tf.nn.relu + elif act == "gelu": + return gelu + elif act == "tanh": + return tf.tanh + else: + raise ValueError("Unsupported activation: %s" % act) def get_assignment_map_from_checkpoint(tvars, init_checkpoint): - """Compute the union of the current variables and checkpoint variables.""" - assignment_map = {} - initialized_variable_names = {} + """Compute the union of the current variables and checkpoint variables.""" + assignment_map = {} + initialized_variable_names = {} - name_to_variable = collections.OrderedDict() - for var in tvars: - name = var.name - m = re.match("^(.*):\\d+$", name) - if m is not None: - name = m.group(1) - name_to_variable[name] = var + name_to_variable = collections.OrderedDict() + for var in tvars: + name = var.name + m = re.match("^(.*):\\d+$", name) + if m is not None: + name = m.group(1) + name_to_variable[name] = var - init_vars = tf.train.list_variables(init_checkpoint) + init_vars = tf.train.list_variables(init_checkpoint) - assignment_map = collections.OrderedDict() - for x in init_vars: - (name, var) = (x[0], x[1]) - if name not in name_to_variable: - continue - assignment_map[name] = name - initialized_variable_names[name] = 1 - initialized_variable_names[name + ":0"] = 1 + assignment_map = collections.OrderedDict() + for x in init_vars: + name, var = (x[0], x[1]) + if name not in name_to_variable: + continue + assignment_map[name] = name + initialized_variable_names[name] = 1 + initialized_variable_names[name + ":0"] = 1 - return (assignment_map, initialized_variable_names) + return (assignment_map, initialized_variable_names) def dropout(input_tensor, dropout_prob): - """Perform dropout. + """Perform dropout. - Args: - input_tensor: float Tensor. - dropout_prob: Python float. The probability of dropping out a value (NOT of - *keeping* a dimension as in `tf.nn.dropout`). + Args: + input_tensor: float Tensor. + dropout_prob: Python float. The probability of dropping out a value (NOT of + *keeping* a dimension as in `tf.nn.dropout`). - Returns: - A version of `input_tensor` with dropout applied. - """ - if dropout_prob is None or dropout_prob == 0.0: - return input_tensor + Returns: + A version of `input_tensor` with dropout applied. + """ + if dropout_prob is None or dropout_prob == 0.0: + return input_tensor - output = tf.nn.dropout(input_tensor, rate=1 - (1.0 - dropout_prob)) - return output + output = tf.nn.dropout(input_tensor, rate=1 - (1.0 - dropout_prob)) + return output def layer_norm(input_tensor, name=None): - """Run layer normalization on the last dimension of the tensor.""" - return tf.contrib.layers.layer_norm( - inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) + """Run layer normalization on the last dimension of the tensor.""" + return tf.contrib.layers.layer_norm( + inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name + ) def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): - """Runs layer normalization followed by dropout.""" - output_tensor = layer_norm(input_tensor, name) - output_tensor = dropout(output_tensor, dropout_prob) - return output_tensor + """Runs layer normalization followed by dropout.""" + output_tensor = layer_norm(input_tensor, name) + output_tensor = dropout(output_tensor, dropout_prob) + return output_tensor def create_initializer(initializer_range=0.02): - """Creates a `truncated_normal_initializer` with the given range.""" - return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range) - - -def embedding_lookup(input_ids, - vocab_size, - embedding_size=128, - initializer_range=0.02, - word_embedding_name="word_embeddings", - use_one_hot_embeddings=False): - """Looks up words embeddings for id tensor. - - Args: - input_ids: int32 Tensor of shape [batch_size, seq_length] containing word - ids. - vocab_size: int. Size of the embedding vocabulary. - embedding_size: int. Width of the word embeddings. - initializer_range: float. Embedding initialization range. - word_embedding_name: string. Name of the embedding table. - use_one_hot_embeddings: bool. If True, use one-hot method for word - embeddings. If False, use `tf.gather()`. - - Returns: - float Tensor of shape [batch_size, seq_length, embedding_size]. - """ - # This function assumes that the input is of shape [batch_size, seq_length, - # num_inputs]. - # - # If the input is a 2D tensor of shape [batch_size, seq_length], we - # reshape to [batch_size, seq_length, 1]. - if input_ids.shape.ndims == 2: - input_ids = tf.expand_dims(input_ids, axis=[-1]) - - embedding_table = tf.compat.v1.get_variable( - name=word_embedding_name, - shape=[vocab_size, embedding_size], - initializer=create_initializer(initializer_range)) - - flat_input_ids = tf.reshape(input_ids, [-1]) - if use_one_hot_embeddings: - one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) - output = tf.matmul(one_hot_input_ids, embedding_table) - else: - output = tf.gather(embedding_table, flat_input_ids) - - input_shape = get_shape_list(input_ids) - - output = tf.reshape(output, - input_shape[0:-1] + [input_shape[-1] * embedding_size]) - return (output, embedding_table) - - -def embedding_postprocessor(input_tensor, - use_token_type=False, - token_type_ids=None, - token_type_vocab_size=16, - token_type_embedding_name="token_type_embeddings", - use_position_embeddings=True, - position_embedding_name="position_embeddings", - initializer_range=0.02, - max_position_embeddings=512, - dropout_prob=0.1): - """Performs various post-processing on a word embedding tensor. - - Args: - input_tensor: float Tensor of shape [batch_size, seq_length, - embedding_size]. - use_token_type: bool. Whether to add embeddings for `token_type_ids`. - token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. - Must be specified if `use_token_type` is True. - token_type_vocab_size: int. The vocabulary size of `token_type_ids`. - token_type_embedding_name: string. The name of the embedding table variable - for token type ids. - use_position_embeddings: bool. Whether to add position embeddings for the - position of each token in the sequence. - position_embedding_name: string. The name of the embedding table variable - for positional embeddings. - initializer_range: float. Range of the weight initialization. - max_position_embeddings: int. Maximum sequence length that might ever be - used with this model. This can be longer than the sequence length of - input_tensor, but cannot be shorter. - dropout_prob: float. Dropout probability applied to the final output tensor. - - Returns: - float tensor with same shape as `input_tensor`. - - Raises: - ValueError: One of the tensor shapes or input values is invalid. - """ - input_shape = get_shape_list(input_tensor, expected_rank=3) - batch_size = input_shape[0] - seq_length = input_shape[1] - width = input_shape[2] - - output = input_tensor - - if use_token_type: - if token_type_ids is None: - raise ValueError("`token_type_ids` must be specified if" - "`use_token_type` is True.") - token_type_table = tf.compat.v1.get_variable( - name=token_type_embedding_name, - shape=[token_type_vocab_size, width], - initializer=create_initializer(initializer_range)) - # This vocab will be small so we always do one-hot here, since it is always - # faster for a small vocabulary. - flat_token_type_ids = tf.reshape(token_type_ids, [-1]) - one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) - token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) - token_type_embeddings = tf.reshape(token_type_embeddings, - [batch_size, seq_length, width]) - output += token_type_embeddings - - if use_position_embeddings: - assert_op = tf.compat.v1.assert_less_equal(seq_length, max_position_embeddings) - with tf.control_dependencies([assert_op]): - full_position_embeddings = tf.compat.v1.get_variable( - name=position_embedding_name, - shape=[max_position_embeddings, width], - initializer=create_initializer(initializer_range)) - # Since the position embedding table is a learned variable, we create it - # using a (long) sequence length `max_position_embeddings`. The actual - # sequence length might be shorter than this, for faster training of - # tasks that do not have long sequences. - # - # So `full_position_embeddings` is effectively an embedding table - # for position [0, 1, 2, ..., max_position_embeddings-1], and the current - # sequence has positions [0, 1, 2, ... seq_length-1], so we can just - # perform a slice. - position_embeddings = tf.slice(full_position_embeddings, [0, 0], - [seq_length, -1]) - num_dims = len(output.shape.as_list()) - - # Only the last two dimensions are relevant (`seq_length` and `width`), so - # we broadcast among the first dimensions, which is typically just - # the batch size. - position_broadcast_shape = [] - for _ in range(num_dims - 2): - position_broadcast_shape.append(1) - position_broadcast_shape.extend([seq_length, width]) - position_embeddings = tf.reshape(position_embeddings, - position_broadcast_shape) - output += position_embeddings - - output = layer_norm_and_dropout(output, dropout_prob) - return output + """Creates a `truncated_normal_initializer` with the given range.""" + return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range) -def create_attention_mask_from_input_mask(from_tensor, to_mask): - """Create 3D attention mask from a 2D tensor mask. - - Args: - from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. - to_mask: int32 Tensor of shape [batch_size, to_seq_length]. - - Returns: - float Tensor of shape [batch_size, from_seq_length, to_seq_length]. - """ - from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) - batch_size = from_shape[0] - from_seq_length = from_shape[1] - - to_shape = get_shape_list(to_mask, expected_rank=2) - to_seq_length = to_shape[1] - - to_mask = tf.cast( - tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) - - # We don't assume that `from_tensor` is a mask (although it could be). We - # don't actually care if we attend *from* padding tokens (only *to* padding) - # tokens so we create a tensor of all ones. - # - # `broadcast_ones` = [batch_size, from_seq_length, 1] - broadcast_ones = tf.ones( - shape=[batch_size, from_seq_length, 1], dtype=tf.float32) - - # Here we broadcast along two dimensions to create the mask. - mask = broadcast_ones * to_mask - - return mask - - -def attention_layer(from_tensor, - to_tensor, - attention_mask=None, - num_attention_heads=1, - size_per_head=512, - query_act=None, - key_act=None, - value_act=None, - attention_probs_dropout_prob=0.0, - initializer_range=0.02, - do_return_2d_tensor=False, - batch_size=None, - from_seq_length=None, - to_seq_length=None): - """Performs multi-headed attention from `from_tensor` to `to_tensor`. - - This is an implementation of multi-headed attention based on "Attention - is all you Need". If `from_tensor` and `to_tensor` are the same, then - this is self-attention. Each timestep in `from_tensor` attends to the - corresponding sequence in `to_tensor`, and returns a fixed-with vector. - - This function first projects `from_tensor` into a "query" tensor and - `to_tensor` into "key" and "value" tensors. These are (effectively) a list - of tensors of length `num_attention_heads`, where each tensor is of shape - [batch_size, seq_length, size_per_head]. - - Then, the query and key tensors are dot-producted and scaled. These are - softmaxed to obtain attention probabilities. The value tensors are then - interpolated by these probabilities, then concatenated back to a single - tensor and returned. - - In practice, the multi-headed attention are done with transposes and - reshapes rather than actual separate tensors. - - Args: - from_tensor: float Tensor of shape [batch_size, from_seq_length, - from_width]. - to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. - attention_mask: (optional) int32 Tensor of shape [batch_size, - from_seq_length, to_seq_length]. The values should be 1 or 0. The - attention scores will effectively be set to -infinity for any positions in - the mask that are 0, and will be unchanged for positions that are 1. - num_attention_heads: int. Number of attention heads. - size_per_head: int. Size of each attention head. - query_act: (optional) Activation function for the query transform. - key_act: (optional) Activation function for the key transform. - value_act: (optional) Activation function for the value transform. - attention_probs_dropout_prob: (optional) float. Dropout probability of the - attention probabilities. - initializer_range: float. Range of the weight initializer. - do_return_2d_tensor: bool. If True, the output will be of shape [batch_size - * from_seq_length, num_attention_heads * size_per_head]. If False, the - output will be of shape [batch_size, from_seq_length, num_attention_heads - * size_per_head]. - batch_size: (Optional) int. If the input is 2D, this might be the batch size - of the 3D version of the `from_tensor` and `to_tensor`. - from_seq_length: (Optional) If the input is 2D, this might be the seq length - of the 3D version of the `from_tensor`. - to_seq_length: (Optional) If the input is 2D, this might be the seq length - of the 3D version of the `to_tensor`. - - Returns: - float Tensor of shape [batch_size, from_seq_length, - num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is - true, this will be of shape [batch_size * from_seq_length, - num_attention_heads * size_per_head]). - - Raises: - ValueError: Any of the arguments or tensor shapes are invalid. - """ - - def transpose_for_scores(input_tensor, batch_size, num_attention_heads, - seq_length, width): - output_tensor = tf.reshape( - input_tensor, [batch_size, seq_length, num_attention_heads, width]) - - output_tensor = tf.transpose(a=output_tensor, perm=[0, 2, 1, 3]) - return output_tensor +def embedding_lookup( + input_ids, + vocab_size, + embedding_size=128, + initializer_range=0.02, + word_embedding_name="word_embeddings", + use_one_hot_embeddings=False, +): + """Looks up words embeddings for id tensor. + + Args: + input_ids: int32 Tensor of shape [batch_size, seq_length] containing word + ids. + vocab_size: int. Size of the embedding vocabulary. + embedding_size: int. Width of the word embeddings. + initializer_range: float. Embedding initialization range. + word_embedding_name: string. Name of the embedding table. + use_one_hot_embeddings: bool. If True, use one-hot method for word + embeddings. If False, use `tf.gather()`. + + Returns: + float Tensor of shape [batch_size, seq_length, embedding_size]. + """ + # This function assumes that the input is of shape [batch_size, seq_length, + # num_inputs]. + # + # If the input is a 2D tensor of shape [batch_size, seq_length], we + # reshape to [batch_size, seq_length, 1]. + if input_ids.shape.ndims == 2: + input_ids = tf.expand_dims(input_ids, axis=[-1]) + + embedding_table = tf.compat.v1.get_variable( + name=word_embedding_name, + shape=[vocab_size, embedding_size], + initializer=create_initializer(initializer_range), + ) + + flat_input_ids = tf.reshape(input_ids, [-1]) + if use_one_hot_embeddings: + one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) + output = tf.matmul(one_hot_input_ids, embedding_table) + else: + output = tf.gather(embedding_table, flat_input_ids) + + input_shape = get_shape_list(input_ids) + + output = tf.reshape(output, input_shape[0:-1] + [input_shape[-1] * embedding_size]) + return (output, embedding_table) + + +def embedding_postprocessor( + input_tensor, + use_token_type=False, + token_type_ids=None, + token_type_vocab_size=16, + token_type_embedding_name="token_type_embeddings", + use_position_embeddings=True, + position_embedding_name="position_embeddings", + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1, +): + """Performs various post-processing on a word embedding tensor. + + Args: + input_tensor: float Tensor of shape [batch_size, seq_length, + embedding_size]. + use_token_type: bool. Whether to add embeddings for `token_type_ids`. + token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. + Must be specified if `use_token_type` is True. + token_type_vocab_size: int. The vocabulary size of `token_type_ids`. + token_type_embedding_name: string. The name of the embedding table variable + for token type ids. + use_position_embeddings: bool. Whether to add position embeddings for the + position of each token in the sequence. + position_embedding_name: string. The name of the embedding table variable + for positional embeddings. + initializer_range: float. Range of the weight initialization. + max_position_embeddings: int. Maximum sequence length that might ever be + used with this model. This can be longer than the sequence length of + input_tensor, but cannot be shorter. + dropout_prob: float. Dropout probability applied to the final output tensor. + + Returns: + float tensor with same shape as `input_tensor`. + + Raises: + ValueError: One of the tensor shapes or input values is invalid. + """ + input_shape = get_shape_list(input_tensor, expected_rank=3) + batch_size = input_shape[0] + seq_length = input_shape[1] + width = input_shape[2] + + output = input_tensor + + if use_token_type: + if token_type_ids is None: + raise ValueError( + "`token_type_ids` must be specified if" "`use_token_type` is True." + ) + token_type_table = tf.compat.v1.get_variable( + name=token_type_embedding_name, + shape=[token_type_vocab_size, width], + initializer=create_initializer(initializer_range), + ) + # This vocab will be small so we always do one-hot here, since it is always + # faster for a small vocabulary. + flat_token_type_ids = tf.reshape(token_type_ids, [-1]) + one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) + token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) + token_type_embeddings = tf.reshape( + token_type_embeddings, [batch_size, seq_length, width] + ) + output += token_type_embeddings + + if use_position_embeddings: + assert_op = tf.compat.v1.assert_less_equal(seq_length, max_position_embeddings) + with tf.control_dependencies([assert_op]): + full_position_embeddings = tf.compat.v1.get_variable( + name=position_embedding_name, + shape=[max_position_embeddings, width], + initializer=create_initializer(initializer_range), + ) + # Since the position embedding table is a learned variable, we create it + # using a (long) sequence length `max_position_embeddings`. The actual + # sequence length might be shorter than this, for faster training of + # tasks that do not have long sequences. + # + # So `full_position_embeddings` is effectively an embedding table + # for position [0, 1, 2, ..., max_position_embeddings-1], and the current + # sequence has positions [0, 1, 2, ... seq_length-1], so we can just + # perform a slice. + position_embeddings = tf.slice( + full_position_embeddings, [0, 0], [seq_length, -1] + ) + num_dims = len(output.shape.as_list()) + + # Only the last two dimensions are relevant (`seq_length` and `width`), so + # we broadcast among the first dimensions, which is typically just + # the batch size. + position_broadcast_shape = [] + for _ in range(num_dims - 2): + position_broadcast_shape.append(1) + position_broadcast_shape.extend([seq_length, width]) + position_embeddings = tf.reshape( + position_embeddings, position_broadcast_shape + ) + output += position_embeddings + + output = layer_norm_and_dropout(output, dropout_prob) + return output + - from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) - to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) +def create_attention_mask_from_input_mask(from_tensor, to_mask): + """Create 3D attention mask from a 2D tensor mask. - if len(from_shape) != len(to_shape): - raise ValueError( - "The rank of `from_tensor` must match the rank of `to_tensor`.") + Args: + from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. + to_mask: int32 Tensor of shape [batch_size, to_seq_length]. - if len(from_shape) == 3: + Returns: + float Tensor of shape [batch_size, from_seq_length, to_seq_length]. + """ + from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) batch_size = from_shape[0] from_seq_length = from_shape[1] + + to_shape = get_shape_list(to_mask, expected_rank=2) to_seq_length = to_shape[1] - elif len(from_shape) == 2: - if (batch_size is None or from_seq_length is None or to_seq_length is None): - raise ValueError( - "When passing in rank 2 tensors to attention_layer, the values " - "for `batch_size`, `from_seq_length`, and `to_seq_length` " - "must all be specified.") - - # Scalar dimensions referenced here: - # B = batch size (number of sequences) - # F = `from_tensor` sequence length - # T = `to_tensor` sequence length - # N = `num_attention_heads` - # H = `size_per_head` - - from_tensor_2d = reshape_to_matrix(from_tensor) - to_tensor_2d = reshape_to_matrix(to_tensor) - - # `query_layer` = [B*F, N*H] - query_layer = tf.compat.v1.layers.dense( - from_tensor_2d, - num_attention_heads * size_per_head, - activation=query_act, - name="query", - kernel_initializer=create_initializer(initializer_range)) - - # `key_layer` = [B*T, N*H] - key_layer = tf.compat.v1.layers.dense( - to_tensor_2d, - num_attention_heads * size_per_head, - activation=key_act, - name="key", - kernel_initializer=create_initializer(initializer_range)) - - # `value_layer` = [B*T, N*H] - value_layer = tf.compat.v1.layers.dense( - to_tensor_2d, - num_attention_heads * size_per_head, - activation=value_act, - name="value", - kernel_initializer=create_initializer(initializer_range)) - - # `query_layer` = [B, N, F, H] - query_layer = transpose_for_scores(query_layer, batch_size, - num_attention_heads, from_seq_length, - size_per_head) - - # `key_layer` = [B, N, T, H] - key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, - to_seq_length, size_per_head) - - # Take the dot product between "query" and "key" to get the raw - # attention scores. - # `attention_scores` = [B, N, F, T] - attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) - attention_scores = tf.multiply(attention_scores, - 1.0 / math.sqrt(float(size_per_head))) - - if attention_mask is not None: - # `attention_mask` = [B, 1, F, T] - attention_mask = tf.expand_dims(attention_mask, axis=[1]) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 - - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_scores += adder - - # Normalize the attention scores to probabilities. - # `attention_probs` = [B, N, F, T] - attention_probs = tf.nn.softmax(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = dropout(attention_probs, attention_probs_dropout_prob) - - # `value_layer` = [B, T, N, H] - value_layer = tf.reshape( - value_layer, - [batch_size, to_seq_length, num_attention_heads, size_per_head]) - - # `value_layer` = [B, N, T, H] - value_layer = tf.transpose(a=value_layer, perm=[0, 2, 1, 3]) - - # `context_layer` = [B, N, F, H] - context_layer = tf.matmul(attention_probs, value_layer) - - # `context_layer` = [B, F, N, H] - context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3]) - - if do_return_2d_tensor: - # `context_layer` = [B*F, N*H] - context_layer = tf.reshape( - context_layer, - [batch_size * from_seq_length, num_attention_heads * size_per_head]) - else: - # `context_layer` = [B, F, N*H] - context_layer = tf.reshape( - context_layer, - [batch_size, from_seq_length, num_attention_heads * size_per_head]) - - return context_layer - - -def transformer_model(input_tensor, - attention_mask=None, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - intermediate_act_fn=gelu, - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - initializer_range=0.02, - do_return_all_layers=False): - """Multi-headed, multi-layer Transformer from "Attention is All You Need". - - This is almost an exact implementation of the original Transformer encoder. - - See the original paper: - https://arxiv.org/abs/1706.03762 - - Also see: - https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py - - Args: - input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. - attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, - seq_length], with 1 for positions that can be attended to and 0 in - positions that should not be. - hidden_size: int. Hidden size of the Transformer. - num_hidden_layers: int. Number of layers (blocks) in the Transformer. - num_attention_heads: int. Number of attention heads in the Transformer. - intermediate_size: int. The size of the "intermediate" (a.k.a., feed - forward) layer. - intermediate_act_fn: function. The non-linear activation function to apply - to the output of the intermediate/feed-forward layer. - hidden_dropout_prob: float. Dropout probability for the hidden layers. - attention_probs_dropout_prob: float. Dropout probability of the attention - probabilities. - initializer_range: float. Range of the initializer (stddev of truncated - normal). - do_return_all_layers: Whether to also return all layers or just the final - layer. - - Returns: - float Tensor of shape [batch_size, seq_length, hidden_size], the final - hidden layer of the Transformer. - - Raises: - ValueError: A Tensor shape or parameter is invalid. - """ - if hidden_size % num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads)) - - attention_head_size = int(hidden_size / num_attention_heads) - input_shape = get_shape_list(input_tensor, expected_rank=3) - batch_size = input_shape[0] - seq_length = input_shape[1] - input_width = input_shape[2] - - # The Transformer performs sum residuals on all layers so the input needs - # to be the same as the hidden size. - if input_width != hidden_size: - raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % - (input_width, hidden_size)) - - # We keep the representation as a 2D tensor to avoid re-shaping it back and - # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on - # the GPU/CPU but may not be free on the TPU, so we want to minimize them to - # help the optimizer. - prev_output = reshape_to_matrix(input_tensor) - - all_layer_outputs = [] - for layer_idx in range(num_hidden_layers): - with tf.compat.v1.variable_scope("layer_%d" % layer_idx): - layer_input = prev_output - - with tf.compat.v1.variable_scope("attention"): - attention_heads = [] - with tf.compat.v1.variable_scope("self"): - attention_head = attention_layer( - from_tensor=layer_input, - to_tensor=layer_input, - attention_mask=attention_mask, - num_attention_heads=num_attention_heads, - size_per_head=attention_head_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - initializer_range=initializer_range, - do_return_2d_tensor=True, - batch_size=batch_size, - from_seq_length=seq_length, - to_seq_length=seq_length) - attention_heads.append(attention_head) - - attention_output = None - if len(attention_heads) == 1: - attention_output = attention_heads[0] - else: - # In the case where we have other sequences, we just concatenate - # them to the self-attention head before the projection. - attention_output = tf.concat(attention_heads, axis=-1) - - # Run a linear projection of `hidden_size` then add a residual - # with `layer_input`. - with tf.compat.v1.variable_scope("output"): - attention_output = tf.compat.v1.layers.dense( - attention_output, - hidden_size, - kernel_initializer=create_initializer(initializer_range)) - attention_output = dropout(attention_output, hidden_dropout_prob) - attention_output = layer_norm(attention_output + layer_input) - - # The activation is only applied to the "intermediate" hidden layer. - with tf.compat.v1.variable_scope("intermediate"): - intermediate_output = tf.compat.v1.layers.dense( - attention_output, - intermediate_size, - activation=intermediate_act_fn, - kernel_initializer=create_initializer(initializer_range)) - - # Down-project back to `hidden_size` then add the residual. - with tf.compat.v1.variable_scope("output"): - layer_output = tf.compat.v1.layers.dense( - intermediate_output, - hidden_size, - kernel_initializer=create_initializer(initializer_range)) - layer_output = dropout(layer_output, hidden_dropout_prob) - layer_output = layer_norm(layer_output + attention_output) - prev_output = layer_output - all_layer_outputs.append(layer_output) - - if do_return_all_layers: - final_outputs = [] - for layer_output in all_layer_outputs: - final_output = reshape_from_matrix(layer_output, input_shape) - final_outputs.append(final_output) - return final_outputs - else: - final_output = reshape_from_matrix(prev_output, input_shape) - return final_output + + to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) + + # We don't assume that `from_tensor` is a mask (although it could be). We + # don't actually care if we attend *from* padding tokens (only *to* padding) + # tokens so we create a tensor of all ones. + # + # `broadcast_ones` = [batch_size, from_seq_length, 1] + broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=tf.float32) + + # Here we broadcast along two dimensions to create the mask. + mask = broadcast_ones * to_mask + + return mask + + +def attention_layer( + from_tensor, + to_tensor, + attention_mask=None, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + do_return_2d_tensor=False, + batch_size=None, + from_seq_length=None, + to_seq_length=None, +): + """Performs multi-headed attention from `from_tensor` to `to_tensor`. + + This is an implementation of multi-headed attention based on "Attention + is all you Need". If `from_tensor` and `to_tensor` are the same, then + this is self-attention. Each timestep in `from_tensor` attends to the + corresponding sequence in `to_tensor`, and returns a fixed-with vector. + + This function first projects `from_tensor` into a "query" tensor and + `to_tensor` into "key" and "value" tensors. These are (effectively) a list + of tensors of length `num_attention_heads`, where each tensor is of shape + [batch_size, seq_length, size_per_head]. + + Then, the query and key tensors are dot-producted and scaled. These are + softmaxed to obtain attention probabilities. The value tensors are then + interpolated by these probabilities, then concatenated back to a single + tensor and returned. + + In practice, the multi-headed attention are done with transposes and + reshapes rather than actual separate tensors. + + Args: + from_tensor: float Tensor of shape [batch_size, from_seq_length, + from_width]. + to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. + attention_mask: (optional) int32 Tensor of shape [batch_size, + from_seq_length, to_seq_length]. The values should be 1 or 0. The + attention scores will effectively be set to -infinity for any positions in + the mask that are 0, and will be unchanged for positions that are 1. + num_attention_heads: int. Number of attention heads. + size_per_head: int. Size of each attention head. + query_act: (optional) Activation function for the query transform. + key_act: (optional) Activation function for the key transform. + value_act: (optional) Activation function for the value transform. + attention_probs_dropout_prob: (optional) float. Dropout probability of the + attention probabilities. + initializer_range: float. Range of the weight initializer. + do_return_2d_tensor: bool. If True, the output will be of shape [batch_size + * from_seq_length, num_attention_heads * size_per_head]. If False, the + output will be of shape [batch_size, from_seq_length, num_attention_heads + * size_per_head]. + batch_size: (Optional) int. If the input is 2D, this might be the batch size + of the 3D version of the `from_tensor` and `to_tensor`. + from_seq_length: (Optional) If the input is 2D, this might be the seq length + of the 3D version of the `from_tensor`. + to_seq_length: (Optional) If the input is 2D, this might be the seq length + of the 3D version of the `to_tensor`. + + Returns: + float Tensor of shape [batch_size, from_seq_length, + num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is + true, this will be of shape [batch_size * from_seq_length, + num_attention_heads * size_per_head]). + + Raises: + ValueError: Any of the arguments or tensor shapes are invalid. + """ + + def transpose_for_scores( + input_tensor, batch_size, num_attention_heads, seq_length, width + ): + output_tensor = tf.reshape( + input_tensor, [batch_size, seq_length, num_attention_heads, width] + ) + + output_tensor = tf.transpose(a=output_tensor, perm=[0, 2, 1, 3]) + return output_tensor + + from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) + to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) + + if len(from_shape) != len(to_shape): + raise ValueError( + "The rank of `from_tensor` must match the rank of `to_tensor`." + ) + + if len(from_shape) == 3: + batch_size = from_shape[0] + from_seq_length = from_shape[1] + to_seq_length = to_shape[1] + elif len(from_shape) == 2: + if batch_size is None or from_seq_length is None or to_seq_length is None: + raise ValueError( + "When passing in rank 2 tensors to attention_layer, the values " + "for `batch_size`, `from_seq_length`, and `to_seq_length` " + "must all be specified." + ) + + # Scalar dimensions referenced here: + # B = batch size (number of sequences) + # F = `from_tensor` sequence length + # T = `to_tensor` sequence length + # N = `num_attention_heads` + # H = `size_per_head` + + from_tensor_2d = reshape_to_matrix(from_tensor) + to_tensor_2d = reshape_to_matrix(to_tensor) + + # `query_layer` = [B*F, N*H] + query_layer = tf.compat.v1.layers.dense( + from_tensor_2d, + num_attention_heads * size_per_head, + activation=query_act, + name="query", + kernel_initializer=create_initializer(initializer_range), + ) + + # `key_layer` = [B*T, N*H] + key_layer = tf.compat.v1.layers.dense( + to_tensor_2d, + num_attention_heads * size_per_head, + activation=key_act, + name="key", + kernel_initializer=create_initializer(initializer_range), + ) + + # `value_layer` = [B*T, N*H] + value_layer = tf.compat.v1.layers.dense( + to_tensor_2d, + num_attention_heads * size_per_head, + activation=value_act, + name="value", + kernel_initializer=create_initializer(initializer_range), + ) + + # `query_layer` = [B, N, F, H] + query_layer = transpose_for_scores( + query_layer, batch_size, num_attention_heads, from_seq_length, size_per_head + ) + + # `key_layer` = [B, N, T, H] + key_layer = transpose_for_scores( + key_layer, batch_size, num_attention_heads, to_seq_length, size_per_head + ) + + # Take the dot product between "query" and "key" to get the raw + # attention scores. + # `attention_scores` = [B, N, F, T] + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + attention_scores = tf.multiply( + attention_scores, 1.0 / math.sqrt(float(size_per_head)) + ) + + if attention_mask is not None: + # `attention_mask` = [B, 1, F, T] + attention_mask = tf.expand_dims(attention_mask, axis=[1]) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 + + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_scores += adder + + # Normalize the attention scores to probabilities. + # `attention_probs` = [B, N, F, T] + attention_probs = tf.nn.softmax(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = dropout(attention_probs, attention_probs_dropout_prob) + + # `value_layer` = [B, T, N, H] + value_layer = tf.reshape( + value_layer, [batch_size, to_seq_length, num_attention_heads, size_per_head] + ) + + # `value_layer` = [B, N, T, H] + value_layer = tf.transpose(a=value_layer, perm=[0, 2, 1, 3]) + + # `context_layer` = [B, N, F, H] + context_layer = tf.matmul(attention_probs, value_layer) + + # `context_layer` = [B, F, N, H] + context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3]) + + if do_return_2d_tensor: + # `context_layer` = [B*F, N*H] + context_layer = tf.reshape( + context_layer, + [batch_size * from_seq_length, num_attention_heads * size_per_head], + ) + else: + # `context_layer` = [B, F, N*H] + context_layer = tf.reshape( + context_layer, + [batch_size, from_seq_length, num_attention_heads * size_per_head], + ) + + return context_layer + + +def transformer_model( + input_tensor, + attention_mask=None, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + intermediate_act_fn=gelu, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + do_return_all_layers=False, +): + """Multi-headed, multi-layer Transformer from "Attention is All You Need". + + This is almost an exact implementation of the original Transformer encoder. + + See the original paper: + https://arxiv.org/abs/1706.03762 + + Also see: + https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py + + Args: + input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. + attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, + seq_length], with 1 for positions that can be attended to and 0 in + positions that should not be. + hidden_size: int. Hidden size of the Transformer. + num_hidden_layers: int. Number of layers (blocks) in the Transformer. + num_attention_heads: int. Number of attention heads in the Transformer. + intermediate_size: int. The size of the "intermediate" (a.k.a., feed + forward) layer. + intermediate_act_fn: function. The non-linear activation function to apply + to the output of the intermediate/feed-forward layer. + hidden_dropout_prob: float. Dropout probability for the hidden layers. + attention_probs_dropout_prob: float. Dropout probability of the attention + probabilities. + initializer_range: float. Range of the initializer (stddev of truncated + normal). + do_return_all_layers: Whether to also return all layers or just the final + layer. + + Returns: + float Tensor of shape [batch_size, seq_length, hidden_size], the final + hidden layer of the Transformer. + + Raises: + ValueError: A Tensor shape or parameter is invalid. + """ + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads) + ) + + attention_head_size = int(hidden_size / num_attention_heads) + input_shape = get_shape_list(input_tensor, expected_rank=3) + batch_size = input_shape[0] + seq_length = input_shape[1] + input_width = input_shape[2] + + # The Transformer performs sum residuals on all layers so the input needs + # to be the same as the hidden size. + if input_width != hidden_size: + raise ValueError( + "The width of the input tensor (%d) != hidden size (%d)" + % (input_width, hidden_size) + ) + + # We keep the representation as a 2D tensor to avoid re-shaping it back and + # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on + # the GPU/CPU but may not be free on the TPU, so we want to minimize them to + # help the optimizer. + prev_output = reshape_to_matrix(input_tensor) + + all_layer_outputs = [] + for layer_idx in range(num_hidden_layers): + with tf.compat.v1.variable_scope("layer_%d" % layer_idx): + layer_input = prev_output + + with tf.compat.v1.variable_scope("attention"): + attention_heads = [] + with tf.compat.v1.variable_scope("self"): + attention_head = attention_layer( + from_tensor=layer_input, + to_tensor=layer_input, + attention_mask=attention_mask, + num_attention_heads=num_attention_heads, + size_per_head=attention_head_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + initializer_range=initializer_range, + do_return_2d_tensor=True, + batch_size=batch_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + ) + attention_heads.append(attention_head) + + attention_output = None + if len(attention_heads) == 1: + attention_output = attention_heads[0] + else: + # In the case where we have other sequences, we just concatenate + # them to the self-attention head before the projection. + attention_output = tf.concat(attention_heads, axis=-1) + + # Run a linear projection of `hidden_size` then add a residual + # with `layer_input`. + with tf.compat.v1.variable_scope("output"): + attention_output = tf.compat.v1.layers.dense( + attention_output, + hidden_size, + kernel_initializer=create_initializer(initializer_range), + ) + attention_output = dropout(attention_output, hidden_dropout_prob) + attention_output = layer_norm(attention_output + layer_input) + + # The activation is only applied to the "intermediate" hidden layer. + with tf.compat.v1.variable_scope("intermediate"): + intermediate_output = tf.compat.v1.layers.dense( + attention_output, + intermediate_size, + activation=intermediate_act_fn, + kernel_initializer=create_initializer(initializer_range), + ) + + # Down-project back to `hidden_size` then add the residual. + with tf.compat.v1.variable_scope("output"): + layer_output = tf.compat.v1.layers.dense( + intermediate_output, + hidden_size, + kernel_initializer=create_initializer(initializer_range), + ) + layer_output = dropout(layer_output, hidden_dropout_prob) + layer_output = layer_norm(layer_output + attention_output) + prev_output = layer_output + all_layer_outputs.append(layer_output) + + if do_return_all_layers: + final_outputs = [] + for layer_output in all_layer_outputs: + final_output = reshape_from_matrix(layer_output, input_shape) + final_outputs.append(final_output) + return final_outputs + else: + final_output = reshape_from_matrix(prev_output, input_shape) + return final_output def get_shape_list(tensor, expected_rank=None, name=None): - """Returns a list of the shape of tensor, preferring static dimensions. - - Args: - tensor: A tf.Tensor object to find the shape of. - expected_rank: (optional) int. The expected rank of `tensor`. If this is - specified and the `tensor` has a different rank, and exception will be - thrown. - name: Optional name of the tensor for the error message. - - Returns: - A list of dimensions of the shape of tensor. All static dimensions will - be returned as python integers, and dynamic dimensions will be returned - as tf.Tensor scalars. - """ - if name is None: - name = tensor.name - - if expected_rank is not None: - assert_rank(tensor, expected_rank, name) - - shape = tensor.shape.as_list() - - non_static_indexes = [] - for (index, dim) in enumerate(shape): - if dim is None: - non_static_indexes.append(index) - - if not non_static_indexes: - return shape + """Returns a list of the shape of tensor, preferring static dimensions. + + Args: + tensor: A tf.Tensor object to find the shape of. + expected_rank: (optional) int. The expected rank of `tensor`. If this is + specified and the `tensor` has a different rank, and exception will be + thrown. + name: Optional name of the tensor for the error message. - dyn_shape = tf.shape(input=tensor) - for index in non_static_indexes: - shape[index] = dyn_shape[index] - return shape + Returns: + A list of dimensions of the shape of tensor. All static dimensions will + be returned as python integers, and dynamic dimensions will be returned + as tf.Tensor scalars. + """ + if name is None: + name = tensor.name + if expected_rank is not None: + assert_rank(tensor, expected_rank, name) -def reshape_to_matrix(input_tensor): - """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" - ndims = input_tensor.shape.ndims - if ndims < 2: - raise ValueError("Input tensor must have at least rank 2. Shape = %s" % - (input_tensor.shape)) - if ndims == 2: - return input_tensor + shape = tensor.shape.as_list() - width = input_tensor.shape[-1] - output_tensor = tf.reshape(input_tensor, [-1, width]) - return output_tensor + non_static_indexes = [] + for index, dim in enumerate(shape): + if dim is None: + non_static_indexes.append(index) + if not non_static_indexes: + return shape -def reshape_from_matrix(output_tensor, orig_shape_list): - """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" - if len(orig_shape_list) == 2: + dyn_shape = tf.shape(input=tensor) + for index in non_static_indexes: + shape[index] = dyn_shape[index] + return shape + + +def reshape_to_matrix(input_tensor): + """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" + ndims = input_tensor.shape.ndims + if ndims < 2: + raise ValueError( + "Input tensor must have at least rank 2. Shape = %s" % (input_tensor.shape) + ) + if ndims == 2: + return input_tensor + + width = input_tensor.shape[-1] + output_tensor = tf.reshape(input_tensor, [-1, width]) return output_tensor - output_shape = get_shape_list(output_tensor) - orig_dims = orig_shape_list[0:-1] - width = output_shape[-1] +def reshape_from_matrix(output_tensor, orig_shape_list): + """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" + if len(orig_shape_list) == 2: + return output_tensor + + output_shape = get_shape_list(output_tensor) + + orig_dims = orig_shape_list[0:-1] + width = output_shape[-1] - return tf.reshape(output_tensor, orig_dims + [width]) + return tf.reshape(output_tensor, orig_dims + [width]) def assert_rank(tensor, expected_rank, name=None): - """Raises an exception if the tensor rank is not of the expected rank. - - Args: - tensor: A tf.Tensor to check the rank of. - expected_rank: Python integer or list of integers, expected rank. - name: Optional name of the tensor for the error message. - - Raises: - ValueError: If the expected shape doesn't match the actual shape. - """ - if name is None: - name = tensor.name - - expected_rank_dict = {} - if isinstance(expected_rank, six.integer_types): - expected_rank_dict[expected_rank] = True - else: - for x in expected_rank: - expected_rank_dict[x] = True - - actual_rank = tensor.shape.ndims - if actual_rank not in expected_rank_dict: - scope_name = tf.compat.v1.get_variable_scope().name - raise ValueError( - "For the tensor `%s` in scope `%s`, the actual rank " - "`%d` (shape = %s) is not equal to the expected rank `%s`" % - (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) + """Raises an exception if the tensor rank is not of the expected rank. + + Args: + tensor: A tf.Tensor to check the rank of. + expected_rank: Python integer or list of integers, expected rank. + name: Optional name of the tensor for the error message. + + Raises: + ValueError: If the expected shape doesn't match the actual shape. + """ + if name is None: + name = tensor.name + + expected_rank_dict = {} + if isinstance(expected_rank, six.integer_types): + expected_rank_dict[expected_rank] = True + else: + for x in expected_rank: + expected_rank_dict[x] = True + + actual_rank = tensor.shape.ndims + if actual_rank not in expected_rank_dict: + scope_name = tf.compat.v1.get_variable_scope().name + raise ValueError( + "For the tensor `%s` in scope `%s`, the actual rank " + "`%d` (shape = %s) is not equal to the expected rank `%s`" + % (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)) + ) diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling_test.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling_test.py index 92ed2cbc..131b1d8e 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling_test.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/modeling_test.py @@ -28,250 +28,265 @@ class BertModelTest(tf.test.TestCase): - class BertModelTester(object): - - def __init__(self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - initializer_range=0.02, - scope=None): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.scope = scope - - def create_model(self): - input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], - self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = BertModelTest.ids_tensor( - [self.batch_size, self.seq_length], vocab_size=2) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = BertModelTest.ids_tensor( - [self.batch_size, self.seq_length], self.type_vocab_size) - - config = modeling.BertConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - initializer_range=self.initializer_range) - - model = modeling.BertModel( - config=config, - is_training=self.is_training, - input_ids=input_ids, - input_mask=input_mask, - token_type_ids=token_type_ids, - scope=self.scope) - - outputs = { - "embedding_output": model.get_embedding_output(), - "sequence_output": model.get_sequence_output(), - "pooled_output": model.get_pooled_output(), - "all_encoder_layers": model.get_all_encoder_layers(), - } - return outputs - - def check_output(self, result): - self.parent.assertAllEqual( - result["embedding_output"].shape, - [self.batch_size, self.seq_length, self.hidden_size]) - - self.parent.assertAllEqual( - result["sequence_output"].shape, - [self.batch_size, self.seq_length, self.hidden_size]) - - self.parent.assertAllEqual(result["pooled_output"].shape, - [self.batch_size, self.hidden_size]) - - def test_default(self): - self.run_tester(BertModelTest.BertModelTester(self)) - - def test_config_to_json_string(self): - config = modeling.BertConfig(vocab_size=99, hidden_size=37) - obj = json.loads(config.to_json_string()) - self.assertEqual(obj["vocab_size"], 99) - self.assertEqual(obj["hidden_size"], 37) - - def run_tester(self, tester): - with self.test_session() as sess: - ops = tester.create_model() - init_op = tf.group(tf.compat.v1.global_variables_initializer(), - tf.compat.v1.local_variables_initializer()) - sess.run(init_op) - output_result = sess.run(ops) - tester.check_output(output_result) - - self.assert_all_tensors_reachable(sess, [init_op, ops]) - - @classmethod - def ids_tensor(cls, shape, vocab_size, rng=None, name=None): - """Creates a random int32 tensor of the shape within the vocab size.""" - if rng is None: - rng = random.Random() - - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(rng.randint(0, vocab_size - 1)) - - return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) - - def assert_all_tensors_reachable(self, sess, outputs): - """Checks that all the tensors in the graph are reachable from outputs.""" - graph = sess.graph - - ignore_strings = [ - "^.*/assert_less_equal/.*$", - "^.*/dilation_rate$", - "^.*/Tensordot/concat$", - "^.*/Tensordot/concat/axis$", - "^testing/.*$", - ] - - ignore_regexes = [re.compile(x) for x in ignore_strings] - - unreachable = self.get_unreachable_ops(graph, outputs) - filtered_unreachable = [] - for x in unreachable: - do_ignore = False - for r in ignore_regexes: - m = r.match(x.name) - if m is not None: - do_ignore = True - if do_ignore: - continue - filtered_unreachable.append(x) - unreachable = filtered_unreachable - - self.assertEqual( - len(unreachable), 0, "The following ops are unreachable: %s" % - (" ".join([x.name for x in unreachable]))) - - @classmethod - def get_unreachable_ops(cls, graph, outputs): - """Finds all of the tensors in graph that are unreachable from outputs.""" - outputs = cls.flatten_recursive(outputs) - output_to_op = collections.defaultdict(list) - op_to_all = collections.defaultdict(list) - assign_out_to_in = collections.defaultdict(list) - - for op in graph.get_operations(): - for x in op.inputs: - op_to_all[op.name].append(x.name) - for y in op.outputs: - output_to_op[y.name].append(op.name) - op_to_all[op.name].append(y.name) - if str(op.type) == "Assign": - for y in op.outputs: - for x in op.inputs: - assign_out_to_in[y.name].append(x.name) - - assign_groups = collections.defaultdict(list) - for out_name in assign_out_to_in.keys(): - name_group = assign_out_to_in[out_name] - for n1 in name_group: - assign_groups[n1].append(out_name) - for n2 in name_group: - if n1 != n2: - assign_groups[n1].append(n2) - - seen_tensors = {} - stack = [x.name for x in outputs] - while stack: - name = stack.pop() - if name in seen_tensors: - continue - seen_tensors[name] = True - - if name in output_to_op: - for op_name in output_to_op[name]: - if op_name in op_to_all: - for input_name in op_to_all[op_name]: - if input_name not in stack: - stack.append(input_name) - - expanded_names = [] - if name in assign_groups: - for assign_name in assign_groups[name]: - expanded_names.append(assign_name) - - for expanded_name in expanded_names: - if expanded_name not in stack: - stack.append(expanded_name) - - unreachable_ops = [] - for op in graph.get_operations(): - is_unreachable = False - all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] - for name in all_names: - if name not in seen_tensors: - is_unreachable = True - if is_unreachable: - unreachable_ops.append(op) - return unreachable_ops - - @classmethod - def flatten_recursive(cls, item): - """Flattens (potentially nested) a tuple/dictionary/list to a list.""" - output = [] - if isinstance(item, list): - output.extend(item) - elif isinstance(item, tuple): - output.extend(list(item)) - elif isinstance(item, dict): - for (_, v) in six.iteritems(item): - output.append(v) - else: - return [item] - - flat_output = [] - for x in output: - flat_output.extend(cls.flatten_recursive(x)) - return flat_output + class BertModelTester(object): + + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.scope = scope + + def create_model(self): + input_ids = BertModelTest.ids_tensor( + [self.batch_size, self.seq_length], self.vocab_size + ) + + input_mask = None + if self.use_input_mask: + input_mask = BertModelTest.ids_tensor( + [self.batch_size, self.seq_length], vocab_size=2 + ) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = BertModelTest.ids_tensor( + [self.batch_size, self.seq_length], self.type_vocab_size + ) + + config = modeling.BertConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + ) + + model = modeling.BertModel( + config=config, + is_training=self.is_training, + input_ids=input_ids, + input_mask=input_mask, + token_type_ids=token_type_ids, + scope=self.scope, + ) + + outputs = { + "embedding_output": model.get_embedding_output(), + "sequence_output": model.get_sequence_output(), + "pooled_output": model.get_pooled_output(), + "all_encoder_layers": model.get_all_encoder_layers(), + } + return outputs + + def check_output(self, result): + self.parent.assertAllEqual( + result["embedding_output"].shape, + [self.batch_size, self.seq_length, self.hidden_size], + ) + + self.parent.assertAllEqual( + result["sequence_output"].shape, + [self.batch_size, self.seq_length, self.hidden_size], + ) + + self.parent.assertAllEqual( + result["pooled_output"].shape, [self.batch_size, self.hidden_size] + ) + + def test_default(self): + self.run_tester(BertModelTest.BertModelTester(self)) + + def test_config_to_json_string(self): + config = modeling.BertConfig(vocab_size=99, hidden_size=37) + obj = json.loads(config.to_json_string()) + self.assertEqual(obj["vocab_size"], 99) + self.assertEqual(obj["hidden_size"], 37) + + def run_tester(self, tester): + with self.test_session() as sess: + ops = tester.create_model() + init_op = tf.group( + tf.compat.v1.global_variables_initializer(), + tf.compat.v1.local_variables_initializer(), + ) + sess.run(init_op) + output_result = sess.run(ops) + tester.check_output(output_result) + + self.assert_all_tensors_reachable(sess, [init_op, ops]) + + @classmethod + def ids_tensor(cls, shape, vocab_size, rng=None, name=None): + """Creates a random int32 tensor of the shape within the vocab size.""" + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) + + def assert_all_tensors_reachable(self, sess, outputs): + """Checks that all the tensors in the graph are reachable from outputs.""" + graph = sess.graph + + ignore_strings = [ + "^.*/assert_less_equal/.*$", + "^.*/dilation_rate$", + "^.*/Tensordot/concat$", + "^.*/Tensordot/concat/axis$", + "^testing/.*$", + ] + + ignore_regexes = [re.compile(x) for x in ignore_strings] + + unreachable = self.get_unreachable_ops(graph, outputs) + filtered_unreachable = [] + for x in unreachable: + do_ignore = False + for r in ignore_regexes: + m = r.match(x.name) + if m is not None: + do_ignore = True + if do_ignore: + continue + filtered_unreachable.append(x) + unreachable = filtered_unreachable + + self.assertEqual( + len(unreachable), + 0, + "The following ops are unreachable: %s" + % (" ".join([x.name for x in unreachable])), + ) + + @classmethod + def get_unreachable_ops(cls, graph, outputs): + """Finds all of the tensors in graph that are unreachable from outputs.""" + outputs = cls.flatten_recursive(outputs) + output_to_op = collections.defaultdict(list) + op_to_all = collections.defaultdict(list) + assign_out_to_in = collections.defaultdict(list) + + for op in graph.get_operations(): + for x in op.inputs: + op_to_all[op.name].append(x.name) + for y in op.outputs: + output_to_op[y.name].append(op.name) + op_to_all[op.name].append(y.name) + if str(op.type) == "Assign": + for y in op.outputs: + for x in op.inputs: + assign_out_to_in[y.name].append(x.name) + + assign_groups = collections.defaultdict(list) + for out_name in assign_out_to_in.keys(): + name_group = assign_out_to_in[out_name] + for n1 in name_group: + assign_groups[n1].append(out_name) + for n2 in name_group: + if n1 != n2: + assign_groups[n1].append(n2) + + seen_tensors = {} + stack = [x.name for x in outputs] + while stack: + name = stack.pop() + if name in seen_tensors: + continue + seen_tensors[name] = True + + if name in output_to_op: + for op_name in output_to_op[name]: + if op_name in op_to_all: + for input_name in op_to_all[op_name]: + if input_name not in stack: + stack.append(input_name) + + expanded_names = [] + if name in assign_groups: + for assign_name in assign_groups[name]: + expanded_names.append(assign_name) + + for expanded_name in expanded_names: + if expanded_name not in stack: + stack.append(expanded_name) + + unreachable_ops = [] + for op in graph.get_operations(): + is_unreachable = False + all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] + for name in all_names: + if name not in seen_tensors: + is_unreachable = True + if is_unreachable: + unreachable_ops.append(op) + return unreachable_ops + + @classmethod + def flatten_recursive(cls, item): + """Flattens (potentially nested) a tuple/dictionary/list to a list.""" + output = [] + if isinstance(item, list): + output.extend(item) + elif isinstance(item, tuple): + output.extend(list(item)) + elif isinstance(item, dict): + for _, v in six.iteritems(item): + output.append(v) + else: + return [item] + + flat_output = [] + for x in output: + flat_output.extend(cls.flatten_recursive(x)) + return flat_output if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization.py index b1692e7b..0c23c8d8 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization.py @@ -23,152 +23,156 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): - """Creates an optimizer training op.""" - global_step = tf.compat.v1.train.get_or_create_global_step() - - learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) - - # Implements linear decay of the learning rate. - learning_rate = tf.compat.v1.train.polynomial_decay( - learning_rate, - global_step, - num_train_steps, - end_learning_rate=0.0, - power=1.0, - cycle=False) - - # Implements linear warmup. I.e., if global_step < num_warmup_steps, the - # learning rate will be `global_step/num_warmup_steps * init_lr`. - if num_warmup_steps: - global_steps_int = tf.cast(global_step, tf.int32) - warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) - - global_steps_float = tf.cast(global_steps_int, tf.float32) - warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) - - warmup_percent_done = global_steps_float / warmup_steps_float - warmup_learning_rate = init_lr * warmup_percent_done - - is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) - learning_rate = ( - (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) - - # It is recommended that you use this optimizer for fine tuning, since this - # is how the model was trained (note that the Adam m/v variables are NOT - # loaded from init_checkpoint.) - optimizer = AdamWeightDecayOptimizer( - learning_rate=learning_rate, - weight_decay_rate=0.01, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-6, - exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) - - if use_tpu: - optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer) - - tvars = tf.compat.v1.trainable_variables() - grads = tf.gradients(ys=loss, xs=tvars) - - # This is how the model was pre-trained. - (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) - - train_op = optimizer.apply_gradients( - zip(grads, tvars), global_step=global_step) - - # Normally the global step update is done inside of `apply_gradients`. - # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use - # a different optimizer, you should probably take this line out. - new_global_step = global_step + 1 - train_op = tf.group(train_op, [global_step.assign(new_global_step)]) - return train_op + """Creates an optimizer training op.""" + global_step = tf.compat.v1.train.get_or_create_global_step() + + learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) + + # Implements linear decay of the learning rate. + learning_rate = tf.compat.v1.train.polynomial_decay( + learning_rate, + global_step, + num_train_steps, + end_learning_rate=0.0, + power=1.0, + cycle=False, + ) + + # Implements linear warmup. I.e., if global_step < num_warmup_steps, the + # learning rate will be `global_step/num_warmup_steps * init_lr`. + if num_warmup_steps: + global_steps_int = tf.cast(global_step, tf.int32) + warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) + + global_steps_float = tf.cast(global_steps_int, tf.float32) + warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) + + warmup_percent_done = global_steps_float / warmup_steps_float + warmup_learning_rate = init_lr * warmup_percent_done + + is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) + learning_rate = ( + 1.0 - is_warmup + ) * learning_rate + is_warmup * warmup_learning_rate + + # It is recommended that you use this optimizer for fine tuning, since this + # is how the model was trained (note that the Adam m/v variables are NOT + # loaded from init_checkpoint.) + optimizer = AdamWeightDecayOptimizer( + learning_rate=learning_rate, + weight_decay_rate=0.01, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], + ) + + if use_tpu: + optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer) + + tvars = tf.compat.v1.trainable_variables() + grads = tf.gradients(ys=loss, xs=tvars) + + # This is how the model was pre-trained. + grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0) + + train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step) + + # Normally the global step update is done inside of `apply_gradients`. + # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use + # a different optimizer, you should probably take this line out. + new_global_step = global_step + 1 + train_op = tf.group(train_op, [global_step.assign(new_global_step)]) + return train_op class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer): - """A basic Adam optimizer that includes "correct" L2 weight decay.""" - - def __init__(self, - learning_rate, - weight_decay_rate=0.0, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-6, - exclude_from_weight_decay=None, - name="AdamWeightDecayOptimizer"): - """Constructs a AdamWeightDecayOptimizer.""" - super(AdamWeightDecayOptimizer, self).__init__(False, name) - - self.learning_rate = learning_rate - self.weight_decay_rate = weight_decay_rate - self.beta_1 = beta_1 - self.beta_2 = beta_2 - self.epsilon = epsilon - self.exclude_from_weight_decay = exclude_from_weight_decay - - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """See base class.""" - assignments = [] - for (grad, param) in grads_and_vars: - if grad is None or param is None: - continue - - param_name = self._get_variable_name(param.name) - - m = tf.compat.v1.get_variable( - name=param_name + "/adam_m", - shape=param.shape.as_list(), - dtype=tf.float32, - trainable=False, - initializer=tf.compat.v1.zeros_initializer()) - v = tf.compat.v1.get_variable( - name=param_name + "/adam_v", - shape=param.shape.as_list(), - dtype=tf.float32, - trainable=False, - initializer=tf.compat.v1.zeros_initializer()) - - # Standard Adam update. - next_m = ( - tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) - next_v = ( - tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, - tf.square(grad))) - - update = next_m / (tf.sqrt(next_v) + self.epsilon) - - # Just adding the square of the weights to the loss function is *not* - # the correct way of using L2 regularization/weight decay with Adam, - # since that will interact with the m and v parameters in strange ways. - # - # Instead we want ot decay the weights in a manner that doesn't interact - # with the m/v parameters. This is equivalent to adding the square - # of the weights to the loss with plain (non-momentum) SGD. - if self._do_use_weight_decay(param_name): - update += self.weight_decay_rate * param - - update_with_lr = self.learning_rate * update - - next_param = param - update_with_lr - - assignments.extend( - [param.assign(next_param), - m.assign(next_m), - v.assign(next_v)]) - return tf.group(*assignments, name=name) - - def _do_use_weight_decay(self, param_name): - """Whether to use L2 weight decay for `param_name`.""" - if not self.weight_decay_rate: - return False - if self.exclude_from_weight_decay: - for r in self.exclude_from_weight_decay: - if re.search(r, param_name) is not None: - return False - return True - - def _get_variable_name(self, param_name): - """Get the variable name from the tensor name.""" - m = re.match("^(.*):\\d+$", param_name) - if m is not None: - param_name = m.group(1) - return param_name + """A basic Adam optimizer that includes "correct" L2 weight decay.""" + + def __init__( + self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + name="AdamWeightDecayOptimizer", + ): + """Constructs a AdamWeightDecayOptimizer.""" + super(AdamWeightDecayOptimizer, self).__init__(False, name) + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """See base class.""" + assignments = [] + for grad, param in grads_and_vars: + if grad is None or param is None: + continue + + param_name = self._get_variable_name(param.name) + + m = tf.compat.v1.get_variable( + name=param_name + "/adam_m", + shape=param.shape.as_list(), + dtype=tf.float32, + trainable=False, + initializer=tf.compat.v1.zeros_initializer(), + ) + v = tf.compat.v1.get_variable( + name=param_name + "/adam_v", + shape=param.shape.as_list(), + dtype=tf.float32, + trainable=False, + initializer=tf.compat.v1.zeros_initializer(), + ) + + # Standard Adam update. + next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad) + next_v = tf.multiply(self.beta_2, v) + tf.multiply( + 1.0 - self.beta_2, tf.square(grad) + ) + + update = next_m / (tf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want ot decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + if self._do_use_weight_decay(param_name): + update += self.weight_decay_rate * param + + update_with_lr = self.learning_rate * update + + next_param = param - update_with_lr + + assignments.extend( + [param.assign(next_param), m.assign(next_m), v.assign(next_v)] + ) + return tf.group(*assignments, name=name) + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True + + def _get_variable_name(self, param_name): + """Get the variable name from the tensor name.""" + m = re.match("^(.*):\\d+$", param_name) + if m is not None: + param_name = m.group(1) + return param_name diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization_test.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization_test.py index 781f9593..114a84a4 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization_test.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/optimization_test.py @@ -22,27 +22,30 @@ class OptimizationTest(tf.test.TestCase): - def test_adam(self): - with self.test_session() as sess: - w = tf.compat.v1.get_variable( - "w", - shape=[3], - initializer=tf.compat.v1.constant_initializer([0.1, -0.2, -0.1])) - x = tf.constant([0.4, 0.2, -0.5]) - loss = tf.reduce_mean(input_tensor=tf.square(x - w)) - tvars = tf.compat.v1.trainable_variables() - grads = tf.gradients(ys=loss, xs=tvars) - global_step = tf.compat.v1.train.get_or_create_global_step() - optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) - train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) - init_op = tf.group(tf.compat.v1.global_variables_initializer(), - tf.compat.v1.local_variables_initializer()) - sess.run(init_op) - for _ in range(100): - sess.run(train_op) - w_np = sess.run(w) - self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) + def test_adam(self): + with self.test_session() as sess: + w = tf.compat.v1.get_variable( + "w", + shape=[3], + initializer=tf.compat.v1.constant_initializer([0.1, -0.2, -0.1]), + ) + x = tf.constant([0.4, 0.2, -0.5]) + loss = tf.reduce_mean(input_tensor=tf.square(x - w)) + tvars = tf.compat.v1.trainable_variables() + grads = tf.gradients(ys=loss, xs=tvars) + global_step = tf.compat.v1.train.get_or_create_global_step() + optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) + train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) + init_op = tf.group( + tf.compat.v1.global_variables_initializer(), + tf.compat.v1.local_variables_initializer(), + ) + sess.run(init_op) + for _ in range(100): + sess.run(train_op) + w_np = sess.run(w) + self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier.py index 7c42cb76..5ddabd7f 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier.py @@ -32,48 +32,61 @@ ## Required parameters flags.DEFINE_string( - "data_dir", None, + "data_dir", + None, "The input data dir. Should contain the .tsv files (or other data files) " - "for the task.") + "for the task.", +) flags.DEFINE_string( - "bert_config_file", None, + "bert_config_file", + None, "The config json file corresponding to the pre-trained BERT model. " - "This specifies the model architecture.") + "This specifies the model architecture.", +) flags.DEFINE_string("task_name", None, "The name of the task to train.") -flags.DEFINE_string("vocab_file", None, - "The vocabulary file that the BERT model was trained on.") +flags.DEFINE_string( + "vocab_file", None, "The vocabulary file that the BERT model was trained on." +) flags.DEFINE_string( - "output_dir", None, - "The output directory where the model checkpoints will be written.") + "output_dir", + None, + "The output directory where the model checkpoints will be written.", +) ## Other parameters flags.DEFINE_string( - "init_checkpoint", None, - "Initial checkpoint (usually from a pre-trained BERT model).") + "init_checkpoint", + None, + "Initial checkpoint (usually from a pre-trained BERT model).", +) flags.DEFINE_bool( - "do_lower_case", True, + "do_lower_case", + True, "Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.") + "models and False for cased models.", +) flags.DEFINE_integer( - "max_seq_length", 128, + "max_seq_length", + 128, "The maximum total input sequence length after WordPiece tokenization. " "Sequences longer than this will be truncated, and sequences shorter " - "than this will be padded.") + "than this will be padded.", +) flags.DEFINE_bool("do_train", False, "Whether to run training.") flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") flags.DEFINE_bool( - "do_predict", False, - "Whether to run the model in inference mode on the test set.") + "do_predict", False, "Whether to run the model in inference mode on the test set." +) flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") @@ -83,899 +96,1000 @@ flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") -flags.DEFINE_float("num_train_epochs", 3.0, - "Total number of training epochs to perform.") +flags.DEFINE_float( + "num_train_epochs", 3.0, "Total number of training epochs to perform." +) flags.DEFINE_float( - "warmup_proportion", 0.1, + "warmup_proportion", + 0.1, "Proportion of training to perform linear learning rate warmup for. " - "E.g., 0.1 = 10% of training.") + "E.g., 0.1 = 10% of training.", +) -flags.DEFINE_integer("save_checkpoints_steps", 1000, - "How often to save the model checkpoint.") +flags.DEFINE_integer( + "save_checkpoints_steps", 1000, "How often to save the model checkpoint." +) -flags.DEFINE_integer("iterations_per_loop", 1000, - "How many steps to make in each estimator call.") +flags.DEFINE_integer( + "iterations_per_loop", 1000, "How many steps to make in each estimator call." +) flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") tf.flags.DEFINE_string( - "tpu_name", None, + "tpu_name", + None, "The Cloud TPU to use for training. This should be either the name " "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " - "url.") + "url.", +) tf.flags.DEFINE_string( - "tpu_zone", None, + "tpu_zone", + None, "[Optional] GCE zone where the Cloud TPU is located in. If not " "specified, we will attempt to automatically detect the GCE project from " - "metadata.") + "metadata.", +) tf.flags.DEFINE_string( - "gcp_project", None, + "gcp_project", + None, "[Optional] Project name for the Cloud TPU-enabled project. If not " "specified, we will attempt to automatically detect the GCE project from " - "metadata.") + "metadata.", +) tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") flags.DEFINE_integer( - "num_tpu_cores", 8, - "Only used if `use_tpu` is True. Total number of TPU cores to use.") + "num_tpu_cores", + 8, + "Only used if `use_tpu` is True. Total number of TPU cores to use.", +) class InputExample(object): - """A single training/test example for simple sequence classification.""" - - def __init__(self, guid, text_a, text_b=None, label=None): - """Constructs a InputExample. - - Args: - guid: Unique id for the example. - text_a: string. The untokenized text of the first sequence. For single - sequence tasks, only this sequence must be specified. - text_b: (Optional) string. The untokenized text of the second sequence. - Only must be specified for sequence pair tasks. - label: (Optional) string. The label of the example. This should be - specified for train and dev examples, but not for test examples. - """ - self.guid = guid - self.text_a = text_a - self.text_b = text_b - self.label = label + """A single training/test example for simple sequence classification.""" + + def __init__(self, guid, text_a, text_b=None, label=None): + """Constructs a InputExample. + + Args: + guid: Unique id for the example. + text_a: string. The untokenized text of the first sequence. For single + sequence tasks, only this sequence must be specified. + text_b: (Optional) string. The untokenized text of the second sequence. + Only must be specified for sequence pair tasks. + label: (Optional) string. The label of the example. This should be + specified for train and dev examples, but not for test examples. + """ + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.label = label class PaddingInputExample(object): - """Fake example so the num input examples is a multiple of the batch size. + """Fake example so the num input examples is a multiple of the batch size. - When running eval/predict on the TPU, we need to pad the number of examples - to be a multiple of the batch size, because the TPU requires a fixed batch - size. The alternative is to drop the last batch, which is bad because it means - the entire output data won't be generated. + When running eval/predict on the TPU, we need to pad the number of examples + to be a multiple of the batch size, because the TPU requires a fixed batch + size. The alternative is to drop the last batch, which is bad because it means + the entire output data won't be generated. - We use this class instead of `None` because treating `None` as padding - battches could cause silent errors. - """ + We use this class instead of `None` because treating `None` as padding + battches could cause silent errors. + """ class InputFeatures(object): - """A single set of features of data.""" + """A single set of features of data.""" - def __init__(self, - input_ids, - input_mask, - segment_ids, - label_id, - is_real_example=True): - self.input_ids = input_ids - self.input_mask = input_mask - self.segment_ids = segment_ids - self.label_id = label_id - self.is_real_example = is_real_example + def __init__( + self, input_ids, input_mask, segment_ids, label_id, is_real_example=True + ): + self.input_ids = input_ids + self.input_mask = input_mask + self.segment_ids = segment_ids + self.label_id = label_id + self.is_real_example = is_real_example class DataProcessor(object): - """Base class for data converters for sequence classification data sets.""" + """Base class for data converters for sequence classification data sets.""" - def get_train_examples(self, data_dir): - """Gets a collection of `InputExample`s for the train set.""" - raise NotImplementedError() + def get_train_examples(self, data_dir): + """Gets a collection of `InputExample`s for the train set.""" + raise NotImplementedError() - def get_dev_examples(self, data_dir): - """Gets a collection of `InputExample`s for the dev set.""" - raise NotImplementedError() + def get_dev_examples(self, data_dir): + """Gets a collection of `InputExample`s for the dev set.""" + raise NotImplementedError() - def get_test_examples(self, data_dir): - """Gets a collection of `InputExample`s for prediction.""" - raise NotImplementedError() + def get_test_examples(self, data_dir): + """Gets a collection of `InputExample`s for prediction.""" + raise NotImplementedError() - def get_labels(self): - """Gets the list of labels for this data set.""" - raise NotImplementedError() + def get_labels(self): + """Gets the list of labels for this data set.""" + raise NotImplementedError() - @classmethod - def _read_tsv(cls, input_file, quotechar=None): - """Reads a tab separated value file.""" - with tf.io.gfile.GFile(input_file, "r") as f: - reader = csv.reader(f, delimiter="\t", quotechar=quotechar) - lines = [] - for line in reader: - lines.append(line) - return lines + @classmethod + def _read_tsv(cls, input_file, quotechar=None): + """Reads a tab separated value file.""" + with tf.io.gfile.GFile(input_file, "r") as f: + reader = csv.reader(f, delimiter="\t", quotechar=quotechar) + lines = [] + for line in reader: + lines.append(line) + return lines class XnliProcessor(DataProcessor): - """Processor for the XNLI data set.""" - - def __init__(self): - self.language = "zh" - - def get_train_examples(self, data_dir): - """See base class.""" - lines = self._read_tsv( - os.path.join(data_dir, "multinli", - "multinli.train.%s.tsv" % self.language)) - examples = [] - for (i, line) in enumerate(lines): - if i == 0: - continue - guid = "train-%d" % (i) - text_a = tokenization.convert_to_unicode(line[0]) - text_b = tokenization.convert_to_unicode(line[1]) - label = tokenization.convert_to_unicode(line[2]) - if label == tokenization.convert_to_unicode("contradictory"): - label = tokenization.convert_to_unicode("contradiction") - examples.append( - InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) - return examples - - def get_dev_examples(self, data_dir): - """See base class.""" - lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) - examples = [] - for (i, line) in enumerate(lines): - if i == 0: - continue - guid = "dev-%d" % (i) - language = tokenization.convert_to_unicode(line[0]) - if language != tokenization.convert_to_unicode(self.language): - continue - text_a = tokenization.convert_to_unicode(line[6]) - text_b = tokenization.convert_to_unicode(line[7]) - label = tokenization.convert_to_unicode(line[1]) - examples.append( - InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) - return examples - - def get_labels(self): - """See base class.""" - return ["contradiction", "entailment", "neutral"] + """Processor for the XNLI data set.""" + + def __init__(self): + self.language = "zh" + + def get_train_examples(self, data_dir): + """See base class.""" + lines = self._read_tsv( + os.path.join(data_dir, "multinli", "multinli.train.%s.tsv" % self.language) + ) + examples = [] + for i, line in enumerate(lines): + if i == 0: + continue + guid = "train-%d" % (i) + text_a = tokenization.convert_to_unicode(line[0]) + text_b = tokenization.convert_to_unicode(line[1]) + label = tokenization.convert_to_unicode(line[2]) + if label == tokenization.convert_to_unicode("contradictory"): + label = tokenization.convert_to_unicode("contradiction") + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) + ) + return examples + + def get_dev_examples(self, data_dir): + """See base class.""" + lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) + examples = [] + for i, line in enumerate(lines): + if i == 0: + continue + guid = "dev-%d" % (i) + language = tokenization.convert_to_unicode(line[0]) + if language != tokenization.convert_to_unicode(self.language): + continue + text_a = tokenization.convert_to_unicode(line[6]) + text_b = tokenization.convert_to_unicode(line[7]) + label = tokenization.convert_to_unicode(line[1]) + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) + ) + return examples + + def get_labels(self): + """See base class.""" + return ["contradiction", "entailment", "neutral"] class MnliProcessor(DataProcessor): - """Processor for the MultiNLI data set (GLUE version).""" - - def get_train_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), - "dev_matched") - - def get_test_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") - - def get_labels(self): - """See base class.""" - return ["contradiction", "entailment", "neutral"] - - def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" - examples = [] - for (i, line) in enumerate(lines): - if i == 0: - continue - guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) - text_a = tokenization.convert_to_unicode(line[8]) - text_b = tokenization.convert_to_unicode(line[9]) - if set_type == "test": - label = "contradiction" - else: - label = tokenization.convert_to_unicode(line[-1]) - examples.append( - InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) - return examples + """Processor for the MultiNLI data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train" + ) + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched" + ) + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test" + ) + + def get_labels(self): + """See base class.""" + return ["contradiction", "entailment", "neutral"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for i, line in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) + text_a = tokenization.convert_to_unicode(line[8]) + text_b = tokenization.convert_to_unicode(line[9]) + if set_type == "test": + label = "contradiction" + else: + label = tokenization.convert_to_unicode(line[-1]) + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) + ) + return examples class MrpcProcessor(DataProcessor): - """Processor for the MRPC data set (GLUE version).""" - - def get_train_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") - - def get_test_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") - - def get_labels(self): - """See base class.""" - return ["0", "1"] - - def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" - examples = [] - for (i, line) in enumerate(lines): - if i == 0: - continue - guid = "%s-%s" % (set_type, i) - text_a = tokenization.convert_to_unicode(line[3]) - text_b = tokenization.convert_to_unicode(line[4]) - if set_type == "test": - label = "0" - else: - label = tokenization.convert_to_unicode(line[0]) - examples.append( - InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) - return examples + """Processor for the MRPC data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train" + ) + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev" + ) + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "test.tsv")), "test" + ) + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for i, line in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, i) + text_a = tokenization.convert_to_unicode(line[3]) + text_b = tokenization.convert_to_unicode(line[4]) + if set_type == "test": + label = "0" + else: + label = tokenization.convert_to_unicode(line[0]) + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) + ) + return examples class ColaProcessor(DataProcessor): - """Processor for the CoLA data set (GLUE version).""" - - def get_train_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") - - def get_test_examples(self, data_dir): - """See base class.""" - return self._create_examples( - self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") - - def get_labels(self): - """See base class.""" - return ["0", "1"] - - def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" - examples = [] - for (i, line) in enumerate(lines): - # Only the test set has a header - if set_type == "test" and i == 0: - continue - guid = "%s-%s" % (set_type, i) - if set_type == "test": - text_a = tokenization.convert_to_unicode(line[1]) - label = "0" - else: - text_a = tokenization.convert_to_unicode(line[3]) - label = tokenization.convert_to_unicode(line[1]) - examples.append( - InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) - return examples - - -def convert_single_example(ex_index, example, label_list, max_seq_length, - tokenizer): - """Converts a single `InputExample` into a single `InputFeatures`.""" - - if isinstance(example, PaddingInputExample): - return InputFeatures( - input_ids=[0] * max_seq_length, - input_mask=[0] * max_seq_length, - segment_ids=[0] * max_seq_length, - label_id=0, - is_real_example=False) - - label_map = {} - for (i, label) in enumerate(label_list): - label_map[label] = i - - tokens_a = tokenizer.tokenize(example.text_a) - tokens_b = None - if example.text_b: - tokens_b = tokenizer.tokenize(example.text_b) - - if tokens_b: - # Modifies `tokens_a` and `tokens_b` in place so that the total - # length is less than the specified length. - # Account for [CLS], [SEP], [SEP] with "- 3" - _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) - else: - # Account for [CLS] and [SEP] with "- 2" - if len(tokens_a) > max_seq_length - 2: - tokens_a = tokens_a[0:(max_seq_length - 2)] - - # The convention in BERT is: - # (a) For sequence pairs: - # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] - # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 - # (b) For single sequences: - # tokens: [CLS] the dog is hairy . [SEP] - # type_ids: 0 0 0 0 0 0 0 - # - # Where "type_ids" are used to indicate whether this is the first - # sequence or the second sequence. The embedding vectors for `type=0` and - # `type=1` were learned during pre-training and are added to the wordpiece - # embedding vector (and position vector). This is not *strictly* necessary - # since the [SEP] token unambiguously separates the sequences, but it makes - # it easier for the model to learn the concept of sequences. - # - # For classification tasks, the first vector (corresponding to [CLS]) is - # used as the "sentence vector". Note that this only makes sense because - # the entire model is fine-tuned. - tokens = [] - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in tokens_a: - tokens.append(token) + """Processor for the CoLA data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train" + ) + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev" + ) + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "test.tsv")), "test" + ) + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for i, line in enumerate(lines): + # Only the test set has a header + if set_type == "test" and i == 0: + continue + guid = "%s-%s" % (set_type, i) + if set_type == "test": + text_a = tokenization.convert_to_unicode(line[1]) + label = "0" + else: + text_a = tokenization.convert_to_unicode(line[3]) + label = tokenization.convert_to_unicode(line[1]) + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=None, label=label) + ) + return examples + + +def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer): + """Converts a single `InputExample` into a single `InputFeatures`.""" + + if isinstance(example, PaddingInputExample): + return InputFeatures( + input_ids=[0] * max_seq_length, + input_mask=[0] * max_seq_length, + segment_ids=[0] * max_seq_length, + label_id=0, + is_real_example=False, + ) + + label_map = {} + for i, label in enumerate(label_list): + label_map[label] = i + + tokens_a = tokenizer.tokenize(example.text_a) + tokens_b = None + if example.text_b: + tokens_b = tokenizer.tokenize(example.text_b) + + if tokens_b: + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP] with "- 3" + _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) + else: + # Account for [CLS] and [SEP] with "- 2" + if len(tokens_a) > max_seq_length - 2: + tokens_a = tokens_a[0 : (max_seq_length - 2)] + + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + segment_ids = [] + tokens.append("[CLS]") segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) - - if tokens_b: - for token in tokens_b: - tokens.append(token) - segment_ids.append(1) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) tokens.append("[SEP]") - segment_ids.append(1) - - input_ids = tokenizer.convert_tokens_to_ids(tokens) - - # The mask has 1 for real tokens and 0 for padding tokens. Only real - # tokens are attended to. - input_mask = [1] * len(input_ids) - - # Zero-pad up to the sequence length. - while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) segment_ids.append(0) - assert len(input_ids) == max_seq_length - assert len(input_mask) == max_seq_length - assert len(segment_ids) == max_seq_length - - label_id = label_map[example.label] - if ex_index < 5: - tf.compat.v1.logging.info("*** Example ***") - tf.compat.v1.logging.info("guid: %s" % (example.guid)) - tf.compat.v1.logging.info("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in tokens])) - tf.compat.v1.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) - tf.compat.v1.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) - tf.compat.v1.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) - tf.compat.v1.logging.info("label: %s (id = %d)" % (example.label, label_id)) - - feature = InputFeatures( - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - label_id=label_id, - is_real_example=True) - return feature + if tokens_b: + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + + # Zero-pad up to the sequence length. + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + label_id = label_map[example.label] + if ex_index < 5: + tf.compat.v1.logging.info("*** Example ***") + tf.compat.v1.logging.info("guid: %s" % (example.guid)) + tf.compat.v1.logging.info( + "tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens]) + ) + tf.compat.v1.logging.info( + "input_ids: %s" % " ".join([str(x) for x in input_ids]) + ) + tf.compat.v1.logging.info( + "input_mask: %s" % " ".join([str(x) for x in input_mask]) + ) + tf.compat.v1.logging.info( + "segment_ids: %s" % " ".join([str(x) for x in segment_ids]) + ) + tf.compat.v1.logging.info("label: %s (id = %d)" % (example.label, label_id)) + + feature = InputFeatures( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + label_id=label_id, + is_real_example=True, + ) + return feature def file_based_convert_examples_to_features( - examples, label_list, max_seq_length, tokenizer, output_file): - """Convert a set of `InputExample`s to a TFRecord file.""" - - writer = tf.io.TFRecordWriter(output_file) - - for (ex_index, example) in enumerate(examples): - if ex_index % 10000 == 0: - tf.compat.v1.logging.info("Writing example %d of %d" % (ex_index, len(examples))) - - feature = convert_single_example(ex_index, example, label_list, - max_seq_length, tokenizer) - - def create_int_feature(values): - f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) - return f - - features = collections.OrderedDict() - features["input_ids"] = create_int_feature(feature.input_ids) - features["input_mask"] = create_int_feature(feature.input_mask) - features["segment_ids"] = create_int_feature(feature.segment_ids) - features["label_ids"] = create_int_feature([feature.label_id]) - features["is_real_example"] = create_int_feature( - [int(feature.is_real_example)]) - - tf_example = tf.train.Example(features=tf.train.Features(feature=features)) - writer.write(tf_example.SerializeToString()) - writer.close() - - -def file_based_input_fn_builder(input_file, seq_length, is_training, - drop_remainder): - """Creates an `input_fn` closure to be passed to TPUEstimator.""" - - name_to_features = { - "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), - "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64), - "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64), - "label_ids": tf.io.FixedLenFeature([], tf.int64), - "is_real_example": tf.io.FixedLenFeature([], tf.int64), - } - - def _decode_record(record, name_to_features): - """Decodes a record to a TensorFlow example.""" - example = tf.io.parse_single_example(serialized=record, features=name_to_features) - - # tf.Example only supports tf.int64, but the TPU only supports tf.int32. - # So cast all int64 to int32. - for name in list(example.keys()): - t = example[name] - if t.dtype == tf.int64: - t = tf.cast(t, dtype=tf.int32) - example[name] = t - - return example - - def input_fn(params): - """The actual input function.""" - batch_size = params["batch_size"] - - # For training, we want a lot of parallel reading and shuffling. - # For eval, we want no shuffling and parallel reading doesn't matter. - d = tf.data.TFRecordDataset(input_file) - if is_training: - d = d.repeat() - d = d.shuffle(buffer_size=100) - - d = d.apply( - tf.data.experimental.map_and_batch( - lambda record: _decode_record(record, name_to_features), - batch_size=batch_size, - drop_remainder=drop_remainder)) - - return d - - return input_fn + examples, label_list, max_seq_length, tokenizer, output_file +): + """Convert a set of `InputExample`s to a TFRecord file.""" + + writer = tf.io.TFRecordWriter(output_file) + + for ex_index, example in enumerate(examples): + if ex_index % 10000 == 0: + tf.compat.v1.logging.info( + "Writing example %d of %d" % (ex_index, len(examples)) + ) + + feature = convert_single_example( + ex_index, example, label_list, max_seq_length, tokenizer + ) + + def create_int_feature(values): + f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) + return f + + features = collections.OrderedDict() + features["input_ids"] = create_int_feature(feature.input_ids) + features["input_mask"] = create_int_feature(feature.input_mask) + features["segment_ids"] = create_int_feature(feature.segment_ids) + features["label_ids"] = create_int_feature([feature.label_id]) + features["is_real_example"] = create_int_feature([int(feature.is_real_example)]) + + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + writer.write(tf_example.SerializeToString()) + writer.close() + + +def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): + """Creates an `input_fn` closure to be passed to TPUEstimator.""" + + name_to_features = { + "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), + "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64), + "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64), + "label_ids": tf.io.FixedLenFeature([], tf.int64), + "is_real_example": tf.io.FixedLenFeature([], tf.int64), + } + + def _decode_record(record, name_to_features): + """Decodes a record to a TensorFlow example.""" + example = tf.io.parse_single_example( + serialized=record, features=name_to_features + ) + + # tf.Example only supports tf.int64, but the TPU only supports tf.int32. + # So cast all int64 to int32. + for name in list(example.keys()): + t = example[name] + if t.dtype == tf.int64: + t = tf.cast(t, dtype=tf.int32) + example[name] = t + + return example + + def input_fn(params): + """The actual input function.""" + batch_size = params["batch_size"] + + # For training, we want a lot of parallel reading and shuffling. + # For eval, we want no shuffling and parallel reading doesn't matter. + d = tf.data.TFRecordDataset(input_file) + if is_training: + d = d.repeat() + d = d.shuffle(buffer_size=100) + + d = d.apply( + tf.data.experimental.map_and_batch( + lambda record: _decode_record(record, name_to_features), + batch_size=batch_size, + drop_remainder=drop_remainder, + ) + ) + + return d + + return input_fn def _truncate_seq_pair(tokens_a, tokens_b, max_length): - """Truncates a sequence pair in place to the maximum length.""" - - # This is a simple heuristic which will always truncate the longer sequence - # one token at a time. This makes more sense than truncating an equal percent - # of tokens from each, since if one sequence is very short then each token - # that's truncated likely contains more information than a longer sequence. - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_length: - break - if len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() - - -def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, - labels, num_labels, use_one_hot_embeddings): - """Creates a classification model.""" - model = modeling.BertModel( - config=bert_config, - is_training=is_training, - input_ids=input_ids, - input_mask=input_mask, - token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) - - # In the demo, we are doing a simple classification task on the entire - # segment. - # - # If you want to use the token-level output, use model.get_sequence_output() - # instead. - output_layer = model.get_pooled_output() - - hidden_size = output_layer.shape[-1].value - - output_weights = tf.compat.v1.get_variable( - "output_weights", [num_labels, hidden_size], - initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)) - - output_bias = tf.compat.v1.get_variable( - "output_bias", [num_labels], initializer=tf.compat.v1.zeros_initializer()) - - with tf.compat.v1.variable_scope("loss"): - if is_training: - # I.e., 0.1 dropout - output_layer = tf.nn.dropout(output_layer, rate=1 - (0.9)) - - logits = tf.matmul(output_layer, output_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - probabilities = tf.nn.softmax(logits, axis=-1) - log_probs = tf.nn.log_softmax(logits, axis=-1) - - one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) - - per_example_loss = -tf.reduce_sum(input_tensor=one_hot_labels * log_probs, axis=-1) - loss = tf.reduce_mean(input_tensor=per_example_loss) - - return (loss, per_example_loss, logits, probabilities) - - -def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, - num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings): - """Returns `model_fn` closure for TPUEstimator.""" - - def model_fn(features, labels, mode, params): # pylint: disable=unused-argument - """The `model_fn` for TPUEstimator.""" - - tf.compat.v1.logging.info("*** Features ***") - for name in sorted(features.keys()): - tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) - - input_ids = features["input_ids"] - input_mask = features["input_mask"] - segment_ids = features["segment_ids"] - label_ids = features["label_ids"] - is_real_example = None - if "is_real_example" in features: - is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) - else: - is_real_example = tf.ones(tf.shape(input=label_ids), dtype=tf.float32) - - is_training = (mode == tf.estimator.ModeKeys.TRAIN) - - (total_loss, per_example_loss, logits, probabilities) = create_model( - bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, - num_labels, use_one_hot_embeddings) - - tvars = tf.compat.v1.trainable_variables() - initialized_variable_names = {} - scaffold_fn = None - if init_checkpoint: - (assignment_map, initialized_variable_names - ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) - if use_tpu: - - def tpu_scaffold(): - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - return tf.compat.v1.train.Scaffold() - - scaffold_fn = tpu_scaffold - else: - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - - tf.compat.v1.logging.info("**** Trainable Variables ****") - for var in tvars: - init_string = "" - if var.name in initialized_variable_names: - init_string = ", *INIT_FROM_CKPT*" - tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, - init_string) - - output_spec = None - if mode == tf.estimator.ModeKeys.TRAIN: - - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) - - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - train_op=train_op, - scaffold_fn=scaffold_fn) - elif mode == tf.estimator.ModeKeys.EVAL: - - def metric_fn(per_example_loss, label_ids, logits, is_real_example): - predictions = tf.argmax(input=logits, axis=-1, output_type=tf.int32) - accuracy = tf.compat.v1.metrics.accuracy( - labels=label_ids, predictions=predictions, weights=is_real_example) - loss = tf.compat.v1.metrics.mean(values=per_example_loss, weights=is_real_example) - return { - "eval_accuracy": accuracy, - "eval_loss": loss, - } - - eval_metrics = (metric_fn, - [per_example_loss, label_ids, logits, is_real_example]) - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - eval_metrics=eval_metrics, - scaffold_fn=scaffold_fn) - else: - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - predictions={"probabilities": probabilities}, - scaffold_fn=scaffold_fn) - return output_spec - - return model_fn + """Truncates a sequence pair in place to the maximum length.""" + + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + +def create_model( + bert_config, + is_training, + input_ids, + input_mask, + segment_ids, + labels, + num_labels, + use_one_hot_embeddings, +): + """Creates a classification model.""" + model = modeling.BertModel( + config=bert_config, + is_training=is_training, + input_ids=input_ids, + input_mask=input_mask, + token_type_ids=segment_ids, + use_one_hot_embeddings=use_one_hot_embeddings, + ) + + # In the demo, we are doing a simple classification task on the entire + # segment. + # + # If you want to use the token-level output, use model.get_sequence_output() + # instead. + output_layer = model.get_pooled_output() + + hidden_size = output_layer.shape[-1].value + + output_weights = tf.compat.v1.get_variable( + "output_weights", + [num_labels, hidden_size], + initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02), + ) + + output_bias = tf.compat.v1.get_variable( + "output_bias", [num_labels], initializer=tf.compat.v1.zeros_initializer() + ) + + with tf.compat.v1.variable_scope("loss"): + if is_training: + # I.e., 0.1 dropout + output_layer = tf.nn.dropout(output_layer, rate=1 - (0.9)) + + logits = tf.matmul(output_layer, output_weights, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + probabilities = tf.nn.softmax(logits, axis=-1) + log_probs = tf.nn.log_softmax(logits, axis=-1) + + one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) + + per_example_loss = -tf.reduce_sum( + input_tensor=one_hot_labels * log_probs, axis=-1 + ) + loss = tf.reduce_mean(input_tensor=per_example_loss) + + return (loss, per_example_loss, logits, probabilities) + + +def model_fn_builder( + bert_config, + num_labels, + init_checkpoint, + learning_rate, + num_train_steps, + num_warmup_steps, + use_tpu, + use_one_hot_embeddings, +): + """Returns `model_fn` closure for TPUEstimator.""" + + def model_fn(features, labels, mode, params): # pylint: disable=unused-argument + """The `model_fn` for TPUEstimator.""" + + tf.compat.v1.logging.info("*** Features ***") + for name in sorted(features.keys()): + tf.compat.v1.logging.info( + " name = %s, shape = %s" % (name, features[name].shape) + ) + + input_ids = features["input_ids"] + input_mask = features["input_mask"] + segment_ids = features["segment_ids"] + label_ids = features["label_ids"] + is_real_example = None + if "is_real_example" in features: + is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) + else: + is_real_example = tf.ones(tf.shape(input=label_ids), dtype=tf.float32) + + is_training = mode == tf.estimator.ModeKeys.TRAIN + + total_loss, per_example_loss, logits, probabilities = create_model( + bert_config, + is_training, + input_ids, + input_mask, + segment_ids, + label_ids, + num_labels, + use_one_hot_embeddings, + ) + + tvars = tf.compat.v1.trainable_variables() + initialized_variable_names = {} + scaffold_fn = None + if init_checkpoint: + assignment_map, initialized_variable_names = ( + modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) + ) + if use_tpu: + + def tpu_scaffold(): + tf.compat.v1.train.init_from_checkpoint( + init_checkpoint, assignment_map + ) + return tf.compat.v1.train.Scaffold() + + scaffold_fn = tpu_scaffold + else: + tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) + + tf.compat.v1.logging.info("**** Trainable Variables ****") + for var in tvars: + init_string = "" + if var.name in initialized_variable_names: + init_string = ", *INIT_FROM_CKPT*" + tf.compat.v1.logging.info( + " name = %s, shape = %s%s", var.name, var.shape, init_string + ) + + output_spec = None + if mode == tf.estimator.ModeKeys.TRAIN: + + train_op = optimization.create_optimizer( + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu + ) + + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn + ) + elif mode == tf.estimator.ModeKeys.EVAL: + + def metric_fn(per_example_loss, label_ids, logits, is_real_example): + predictions = tf.argmax(input=logits, axis=-1, output_type=tf.int32) + accuracy = tf.compat.v1.metrics.accuracy( + labels=label_ids, predictions=predictions, weights=is_real_example + ) + loss = tf.compat.v1.metrics.mean( + values=per_example_loss, weights=is_real_example + ) + return { + "eval_accuracy": accuracy, + "eval_loss": loss, + } + + eval_metrics = ( + metric_fn, + [per_example_loss, label_ids, logits, is_real_example], + ) + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, + loss=total_loss, + eval_metrics=eval_metrics, + scaffold_fn=scaffold_fn, + ) + else: + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, + predictions={"probabilities": probabilities}, + scaffold_fn=scaffold_fn, + ) + return output_spec + + return model_fn # This function is not used by this file but is still used by the Colab and # people who depend on it. def input_fn_builder(features, seq_length, is_training, drop_remainder): - """Creates an `input_fn` closure to be passed to TPUEstimator.""" - - all_input_ids = [] - all_input_mask = [] - all_segment_ids = [] - all_label_ids = [] - - for feature in features: - all_input_ids.append(feature.input_ids) - all_input_mask.append(feature.input_mask) - all_segment_ids.append(feature.segment_ids) - all_label_ids.append(feature.label_id) - - def input_fn(params): - """The actual input function.""" - batch_size = params["batch_size"] - - num_examples = len(features) - - # This is for demo purposes and does NOT scale to large data sets. We do - # not use Dataset.from_generator() because that uses tf.py_func which is - # not TPU compatible. The right way to load data is with TFRecordReader. - d = tf.data.Dataset.from_tensor_slices({ - "input_ids": - tf.constant( - all_input_ids, shape=[num_examples, seq_length], - dtype=tf.int32), - "input_mask": - tf.constant( - all_input_mask, - shape=[num_examples, seq_length], - dtype=tf.int32), - "segment_ids": - tf.constant( - all_segment_ids, - shape=[num_examples, seq_length], - dtype=tf.int32), - "label_ids": - tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), - }) - - if is_training: - d = d.repeat() - d = d.shuffle(buffer_size=100) - - d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) - return d - - return input_fn + """Creates an `input_fn` closure to be passed to TPUEstimator.""" + + all_input_ids = [] + all_input_mask = [] + all_segment_ids = [] + all_label_ids = [] + + for feature in features: + all_input_ids.append(feature.input_ids) + all_input_mask.append(feature.input_mask) + all_segment_ids.append(feature.segment_ids) + all_label_ids.append(feature.label_id) + + def input_fn(params): + """The actual input function.""" + batch_size = params["batch_size"] + + num_examples = len(features) + + # This is for demo purposes and does NOT scale to large data sets. We do + # not use Dataset.from_generator() because that uses tf.py_func which is + # not TPU compatible. The right way to load data is with TFRecordReader. + d = tf.data.Dataset.from_tensor_slices( + { + "input_ids": tf.constant( + all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32 + ), + "input_mask": tf.constant( + all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32 + ), + "segment_ids": tf.constant( + all_segment_ids, shape=[num_examples, seq_length], dtype=tf.int32 + ), + "label_ids": tf.constant( + all_label_ids, shape=[num_examples], dtype=tf.int32 + ), + } + ) + + if is_training: + d = d.repeat() + d = d.shuffle(buffer_size=100) + + d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) + return d + + return input_fn # This function is not used by this file but is still used by the Colab and # people who depend on it. -def convert_examples_to_features(examples, label_list, max_seq_length, - tokenizer): - """Convert a set of `InputExample`s to a list of `InputFeatures`.""" +def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): + """Convert a set of `InputExample`s to a list of `InputFeatures`.""" - features = [] - for (ex_index, example) in enumerate(examples): - if ex_index % 10000 == 0: - tf.compat.v1.logging.info("Writing example %d of %d" % (ex_index, len(examples))) + features = [] + for ex_index, example in enumerate(examples): + if ex_index % 10000 == 0: + tf.compat.v1.logging.info( + "Writing example %d of %d" % (ex_index, len(examples)) + ) - feature = convert_single_example(ex_index, example, label_list, - max_seq_length, tokenizer) + feature = convert_single_example( + ex_index, example, label_list, max_seq_length, tokenizer + ) - features.append(feature) - return features + features.append(feature) + return features def main(_): - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) - - processors = { - "cola": ColaProcessor, - "mnli": MnliProcessor, - "mrpc": MrpcProcessor, - "xnli": XnliProcessor, - } - - tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, - FLAGS.init_checkpoint) - - if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: - raise ValueError( - "At least one of `do_train`, `do_eval` or `do_predict' must be True.") - - bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) - - if FLAGS.max_seq_length > bert_config.max_position_embeddings: - raise ValueError( - "Cannot use sequence length %d because the BERT model " - "was only trained up to sequence length %d" % - (FLAGS.max_seq_length, bert_config.max_position_embeddings)) - - tf.io.gfile.makedirs(FLAGS.output_dir) - - task_name = FLAGS.task_name.lower() - - if task_name not in processors: - raise ValueError("Task not found: %s" % (task_name)) - - processor = processors[task_name]() - - label_list = processor.get_labels() - - tokenizer = tokenization.FullTokenizer( - vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) - - tpu_cluster_resolver = None - if FLAGS.use_tpu and FLAGS.tpu_name: - tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( - FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.compat.v1.estimator.tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=FLAGS.master, - model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, - tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - - train_examples = None - num_train_steps = None - num_warmup_steps = None - if FLAGS.do_train: - train_examples = processor.get_train_examples(FLAGS.data_dir) - num_train_steps = int( - len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) - num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) - - model_fn = model_fn_builder( - bert_config=bert_config, - num_labels=len(label_list), - init_checkpoint=FLAGS.init_checkpoint, - learning_rate=FLAGS.learning_rate, - num_train_steps=num_train_steps, - num_warmup_steps=num_warmup_steps, - use_tpu=FLAGS.use_tpu, - use_one_hot_embeddings=FLAGS.use_tpu) - - # If TPU is not available, this will fall back to normal Estimator on CPU - # or GPU. - estimator = tf.compat.v1.estimator.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, - model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size, - predict_batch_size=FLAGS.predict_batch_size) - - if FLAGS.do_train: - train_file = os.path.join(FLAGS.output_dir, "train.tf_record") - file_based_convert_examples_to_features( - train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) - tf.compat.v1.logging.info("***** Running training *****") - tf.compat.v1.logging.info(" Num examples = %d", len(train_examples)) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) - tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) - train_input_fn = file_based_input_fn_builder( - input_file=train_file, - seq_length=FLAGS.max_seq_length, - is_training=True, - drop_remainder=True) - estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) - - if FLAGS.do_eval: - eval_examples = processor.get_dev_examples(FLAGS.data_dir) - num_actual_eval_examples = len(eval_examples) - if FLAGS.use_tpu: - # TPU requires a fixed batch size for all batches, therefore the number - # of examples must be a multiple of the batch size, or else examples - # will get dropped. So we pad with fake examples which are ignored - # later on. These do NOT count towards the metric (all tf.metrics - # support a per-instance weight, and these get a weight of 0.0). - while len(eval_examples) % FLAGS.eval_batch_size != 0: - eval_examples.append(PaddingInputExample()) - - eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") - file_based_convert_examples_to_features( - eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) - - tf.compat.v1.logging.info("***** Running evaluation *****") - tf.compat.v1.logging.info(" Num examples = %d (%d actual, %d padding)", - len(eval_examples), num_actual_eval_examples, - len(eval_examples) - num_actual_eval_examples) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size) - - # This tells the estimator to run through the entire set. - eval_steps = None - # However, if running eval on the TPU, you will need to specify the - # number of steps. - if FLAGS.use_tpu: - assert len(eval_examples) % FLAGS.eval_batch_size == 0 - eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) - - eval_drop_remainder = True if FLAGS.use_tpu else False - eval_input_fn = file_based_input_fn_builder( - input_file=eval_file, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=eval_drop_remainder) - - result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) - - output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") - with tf.io.gfile.GFile(output_eval_file, "w") as writer: - tf.compat.v1.logging.info("***** Eval results *****") - for key in sorted(result.keys()): - tf.compat.v1.logging.info(" %s = %s", key, str(result[key])) - writer.write("%s = %s\n" % (key, str(result[key]))) - - if FLAGS.do_predict: - predict_examples = processor.get_test_examples(FLAGS.data_dir) - num_actual_predict_examples = len(predict_examples) - if FLAGS.use_tpu: - # TPU requires a fixed batch size for all batches, therefore the number - # of examples must be a multiple of the batch size, or else examples - # will get dropped. So we pad with fake examples which are ignored - # later on. - while len(predict_examples) % FLAGS.predict_batch_size != 0: - predict_examples.append(PaddingInputExample()) - - predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") - file_based_convert_examples_to_features(predict_examples, label_list, - FLAGS.max_seq_length, tokenizer, - predict_file) - - tf.compat.v1.logging.info("***** Running prediction*****") - tf.compat.v1.logging.info(" Num examples = %d (%d actual, %d padding)", - len(predict_examples), num_actual_predict_examples, - len(predict_examples) - num_actual_predict_examples) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size) - - predict_drop_remainder = True if FLAGS.use_tpu else False - predict_input_fn = file_based_input_fn_builder( - input_file=predict_file, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=predict_drop_remainder) - - result = estimator.predict(input_fn=predict_input_fn) - - output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") - with tf.io.gfile.GFile(output_predict_file, "w") as writer: - num_written_lines = 0 - tf.compat.v1.logging.info("***** Predict results *****") - for (i, prediction) in enumerate(result): - probabilities = prediction["probabilities"] - if i >= num_actual_predict_examples: - break - output_line = "\t".join( - str(class_probability) - for class_probability in probabilities) + "\n" - writer.write(output_line) - num_written_lines += 1 - assert num_written_lines == num_actual_predict_examples + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + processors = { + "cola": ColaProcessor, + "mnli": MnliProcessor, + "mrpc": MrpcProcessor, + "xnli": XnliProcessor, + } + + tokenization.validate_case_matches_checkpoint( + FLAGS.do_lower_case, FLAGS.init_checkpoint + ) + + if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: + raise ValueError( + "At least one of `do_train`, `do_eval` or `do_predict' must be True." + ) + + bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) + + if FLAGS.max_seq_length > bert_config.max_position_embeddings: + raise ValueError( + "Cannot use sequence length %d because the BERT model " + "was only trained up to sequence length %d" + % (FLAGS.max_seq_length, bert_config.max_position_embeddings) + ) + + tf.io.gfile.makedirs(FLAGS.output_dir) + + task_name = FLAGS.task_name.lower() + + if task_name not in processors: + raise ValueError("Task not found: %s" % (task_name)) + + processor = processors[task_name]() + + label_list = processor.get_labels() + + tokenizer = tokenization.FullTokenizer( + vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case + ) + + tpu_cluster_resolver = None + if FLAGS.use_tpu and FLAGS.tpu_name: + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project + ) + + is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 + run_config = tf.compat.v1.estimator.tpu.RunConfig( + cluster=tpu_cluster_resolver, + master=FLAGS.master, + model_dir=FLAGS.output_dir, + save_checkpoints_steps=FLAGS.save_checkpoints_steps, + tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_tpu_cores, + per_host_input_for_training=is_per_host, + ), + ) + + train_examples = None + num_train_steps = None + num_warmup_steps = None + if FLAGS.do_train: + train_examples = processor.get_train_examples(FLAGS.data_dir) + num_train_steps = int( + len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs + ) + num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) + + model_fn = model_fn_builder( + bert_config=bert_config, + num_labels=len(label_list), + init_checkpoint=FLAGS.init_checkpoint, + learning_rate=FLAGS.learning_rate, + num_train_steps=num_train_steps, + num_warmup_steps=num_warmup_steps, + use_tpu=FLAGS.use_tpu, + use_one_hot_embeddings=FLAGS.use_tpu, + ) + + # If TPU is not available, this will fall back to normal Estimator on CPU + # or GPU. + estimator = tf.compat.v1.estimator.tpu.TPUEstimator( + use_tpu=FLAGS.use_tpu, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS.train_batch_size, + eval_batch_size=FLAGS.eval_batch_size, + predict_batch_size=FLAGS.predict_batch_size, + ) + + if FLAGS.do_train: + train_file = os.path.join(FLAGS.output_dir, "train.tf_record") + file_based_convert_examples_to_features( + train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file + ) + tf.compat.v1.logging.info("***** Running training *****") + tf.compat.v1.logging.info(" Num examples = %d", len(train_examples)) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) + tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) + train_input_fn = file_based_input_fn_builder( + input_file=train_file, + seq_length=FLAGS.max_seq_length, + is_training=True, + drop_remainder=True, + ) + estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) + + if FLAGS.do_eval: + eval_examples = processor.get_dev_examples(FLAGS.data_dir) + num_actual_eval_examples = len(eval_examples) + if FLAGS.use_tpu: + # TPU requires a fixed batch size for all batches, therefore the number + # of examples must be a multiple of the batch size, or else examples + # will get dropped. So we pad with fake examples which are ignored + # later on. These do NOT count towards the metric (all tf.metrics + # support a per-instance weight, and these get a weight of 0.0). + while len(eval_examples) % FLAGS.eval_batch_size != 0: + eval_examples.append(PaddingInputExample()) + + eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") + file_based_convert_examples_to_features( + eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file + ) + + tf.compat.v1.logging.info("***** Running evaluation *****") + tf.compat.v1.logging.info( + " Num examples = %d (%d actual, %d padding)", + len(eval_examples), + num_actual_eval_examples, + len(eval_examples) - num_actual_eval_examples, + ) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size) + + # This tells the estimator to run through the entire set. + eval_steps = None + # However, if running eval on the TPU, you will need to specify the + # number of steps. + if FLAGS.use_tpu: + assert len(eval_examples) % FLAGS.eval_batch_size == 0 + eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) + + eval_drop_remainder = True if FLAGS.use_tpu else False + eval_input_fn = file_based_input_fn_builder( + input_file=eval_file, + seq_length=FLAGS.max_seq_length, + is_training=False, + drop_remainder=eval_drop_remainder, + ) + + result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) + + output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") + with tf.io.gfile.GFile(output_eval_file, "w") as writer: + tf.compat.v1.logging.info("***** Eval results *****") + for key in sorted(result.keys()): + tf.compat.v1.logging.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) + + if FLAGS.do_predict: + predict_examples = processor.get_test_examples(FLAGS.data_dir) + num_actual_predict_examples = len(predict_examples) + if FLAGS.use_tpu: + # TPU requires a fixed batch size for all batches, therefore the number + # of examples must be a multiple of the batch size, or else examples + # will get dropped. So we pad with fake examples which are ignored + # later on. + while len(predict_examples) % FLAGS.predict_batch_size != 0: + predict_examples.append(PaddingInputExample()) + + predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") + file_based_convert_examples_to_features( + predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file + ) + + tf.compat.v1.logging.info("***** Running prediction*****") + tf.compat.v1.logging.info( + " Num examples = %d (%d actual, %d padding)", + len(predict_examples), + num_actual_predict_examples, + len(predict_examples) - num_actual_predict_examples, + ) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size) + + predict_drop_remainder = True if FLAGS.use_tpu else False + predict_input_fn = file_based_input_fn_builder( + input_file=predict_file, + seq_length=FLAGS.max_seq_length, + is_training=False, + drop_remainder=predict_drop_remainder, + ) + + result = estimator.predict(input_fn=predict_input_fn) + + output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") + with tf.io.gfile.GFile(output_predict_file, "w") as writer: + num_written_lines = 0 + tf.compat.v1.logging.info("***** Predict results *****") + for i, prediction in enumerate(result): + probabilities = prediction["probabilities"] + if i >= num_actual_predict_examples: + break + output_line = ( + "\t".join( + str(class_probability) for class_probability in probabilities + ) + + "\n" + ) + writer.write(output_line) + num_written_lines += 1 + assert num_written_lines == num_actual_predict_examples if __name__ == "__main__": - flags.mark_flag_as_required("data_dir") - flags.mark_flag_as_required("task_name") - flags.mark_flag_as_required("vocab_file") - flags.mark_flag_as_required("bert_config_file") - flags.mark_flag_as_required("output_dir") - tf.compat.v1.app.run() + flags.mark_flag_as_required("data_dir") + flags.mark_flag_as_required("task_name") + flags.mark_flag_as_required("vocab_file") + flags.mark_flag_as_required("bert_config_file") + flags.mark_flag_as_required("output_dir") + tf.compat.v1.app.run() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier_with_tfhub.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier_with_tfhub.py index 66290f5e..b8eff987 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier_with_tfhub.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_classifier_with_tfhub.py @@ -30,285 +30,324 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "bert_hub_module_handle", None, - "Handle for the BERT TF-Hub module.") - - -def create_model(is_training, input_ids, input_mask, segment_ids, labels, - num_labels, bert_hub_module_handle): - """Creates a classification model.""" - tags = set() - if is_training: - tags.add("train") - bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True) - bert_inputs = dict( - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids) - bert_outputs = bert_module( - inputs=bert_inputs, - signature="tokens", - as_dict=True) - - # In the demo, we are doing a simple classification task on the entire - # segment. - # - # If you want to use the token-level output, use - # bert_outputs["sequence_output"] instead. - output_layer = bert_outputs["pooled_output"] - - hidden_size = output_layer.shape[-1].value - - output_weights = tf.compat.v1.get_variable( - "output_weights", [num_labels, hidden_size], - initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)) - - output_bias = tf.compat.v1.get_variable( - "output_bias", [num_labels], initializer=tf.compat.v1.zeros_initializer()) - - with tf.compat.v1.variable_scope("loss"): + "bert_hub_module_handle", None, "Handle for the BERT TF-Hub module." +) + + +def create_model( + is_training, + input_ids, + input_mask, + segment_ids, + labels, + num_labels, + bert_hub_module_handle, +): + """Creates a classification model.""" + tags = set() if is_training: - # I.e., 0.1 dropout - output_layer = tf.nn.dropout(output_layer, rate=1 - (0.9)) - - logits = tf.matmul(output_layer, output_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - probabilities = tf.nn.softmax(logits, axis=-1) - log_probs = tf.nn.log_softmax(logits, axis=-1) - - one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) - - per_example_loss = -tf.reduce_sum(input_tensor=one_hot_labels * log_probs, axis=-1) - loss = tf.reduce_mean(input_tensor=per_example_loss) - - return (loss, per_example_loss, logits, probabilities) - - -def model_fn_builder(num_labels, learning_rate, num_train_steps, - num_warmup_steps, use_tpu, bert_hub_module_handle): - """Returns `model_fn` closure for TPUEstimator.""" - - def model_fn(features, labels, mode, params): # pylint: disable=unused-argument - """The `model_fn` for TPUEstimator.""" - - tf.compat.v1.logging.info("*** Features ***") - for name in sorted(features.keys()): - tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) - - input_ids = features["input_ids"] - input_mask = features["input_mask"] - segment_ids = features["segment_ids"] - label_ids = features["label_ids"] - - is_training = (mode == tf.estimator.ModeKeys.TRAIN) - - (total_loss, per_example_loss, logits, probabilities) = create_model( - is_training, input_ids, input_mask, segment_ids, label_ids, num_labels, - bert_hub_module_handle) - - output_spec = None - if mode == tf.estimator.ModeKeys.TRAIN: - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) - - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - train_op=train_op) - elif mode == tf.estimator.ModeKeys.EVAL: - - def metric_fn(per_example_loss, label_ids, logits): - predictions = tf.argmax(input=logits, axis=-1, output_type=tf.int32) - accuracy = tf.compat.v1.metrics.accuracy(label_ids, predictions) - loss = tf.compat.v1.metrics.mean(per_example_loss) - return { - "eval_accuracy": accuracy, - "eval_loss": loss, - } - - eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - eval_metrics=eval_metrics) - elif mode == tf.estimator.ModeKeys.PREDICT: - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, predictions={"probabilities": probabilities}) - else: - raise ValueError( - "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode)) - - return output_spec - - return model_fn + tags.add("train") + bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True) + bert_inputs = dict( + input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids + ) + bert_outputs = bert_module(inputs=bert_inputs, signature="tokens", as_dict=True) + + # In the demo, we are doing a simple classification task on the entire + # segment. + # + # If you want to use the token-level output, use + # bert_outputs["sequence_output"] instead. + output_layer = bert_outputs["pooled_output"] + + hidden_size = output_layer.shape[-1].value + + output_weights = tf.compat.v1.get_variable( + "output_weights", + [num_labels, hidden_size], + initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02), + ) + + output_bias = tf.compat.v1.get_variable( + "output_bias", [num_labels], initializer=tf.compat.v1.zeros_initializer() + ) + + with tf.compat.v1.variable_scope("loss"): + if is_training: + # I.e., 0.1 dropout + output_layer = tf.nn.dropout(output_layer, rate=1 - (0.9)) + + logits = tf.matmul(output_layer, output_weights, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + probabilities = tf.nn.softmax(logits, axis=-1) + log_probs = tf.nn.log_softmax(logits, axis=-1) + + one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) + + per_example_loss = -tf.reduce_sum( + input_tensor=one_hot_labels * log_probs, axis=-1 + ) + loss = tf.reduce_mean(input_tensor=per_example_loss) + + return (loss, per_example_loss, logits, probabilities) + + +def model_fn_builder( + num_labels, + learning_rate, + num_train_steps, + num_warmup_steps, + use_tpu, + bert_hub_module_handle, +): + """Returns `model_fn` closure for TPUEstimator.""" + + def model_fn(features, labels, mode, params): # pylint: disable=unused-argument + """The `model_fn` for TPUEstimator.""" + + tf.compat.v1.logging.info("*** Features ***") + for name in sorted(features.keys()): + tf.compat.v1.logging.info( + " name = %s, shape = %s" % (name, features[name].shape) + ) + + input_ids = features["input_ids"] + input_mask = features["input_mask"] + segment_ids = features["segment_ids"] + label_ids = features["label_ids"] + + is_training = mode == tf.estimator.ModeKeys.TRAIN + + total_loss, per_example_loss, logits, probabilities = create_model( + is_training, + input_ids, + input_mask, + segment_ids, + label_ids, + num_labels, + bert_hub_module_handle, + ) + + output_spec = None + if mode == tf.estimator.ModeKeys.TRAIN: + train_op = optimization.create_optimizer( + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu + ) + + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, loss=total_loss, train_op=train_op + ) + elif mode == tf.estimator.ModeKeys.EVAL: + + def metric_fn(per_example_loss, label_ids, logits): + predictions = tf.argmax(input=logits, axis=-1, output_type=tf.int32) + accuracy = tf.compat.v1.metrics.accuracy(label_ids, predictions) + loss = tf.compat.v1.metrics.mean(per_example_loss) + return { + "eval_accuracy": accuracy, + "eval_loss": loss, + } + + eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, loss=total_loss, eval_metrics=eval_metrics + ) + elif mode == tf.estimator.ModeKeys.PREDICT: + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, predictions={"probabilities": probabilities} + ) + else: + raise ValueError( + "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode) + ) + + return output_spec + + return model_fn def create_tokenizer_from_hub_module(bert_hub_module_handle): - """Get the vocab file and casing info from the Hub module.""" - with tf.Graph().as_default(): - bert_module = hub.Module(bert_hub_module_handle) - tokenization_info = bert_module(signature="tokenization_info", as_dict=True) - with tf.compat.v1.Session() as sess: - vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"], - tokenization_info["do_lower_case"]]) - return tokenization.FullTokenizer( - vocab_file=vocab_file, do_lower_case=do_lower_case) + """Get the vocab file and casing info from the Hub module.""" + with tf.Graph().as_default(): + bert_module = hub.Module(bert_hub_module_handle) + tokenization_info = bert_module(signature="tokenization_info", as_dict=True) + with tf.compat.v1.Session() as sess: + vocab_file, do_lower_case = sess.run( + [tokenization_info["vocab_file"], tokenization_info["do_lower_case"]] + ) + return tokenization.FullTokenizer( + vocab_file=vocab_file, do_lower_case=do_lower_case + ) def main(_): - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) - - processors = { - "cola": run_classifier.ColaProcessor, - "mnli": run_classifier.MnliProcessor, - "mrpc": run_classifier.MrpcProcessor, - } - - if not FLAGS.do_train and not FLAGS.do_eval: - raise ValueError("At least one of `do_train` or `do_eval` must be True.") - - tf.io.gfile.makedirs(FLAGS.output_dir) - - task_name = FLAGS.task_name.lower() - - if task_name not in processors: - raise ValueError("Task not found: %s" % (task_name)) - - processor = processors[task_name]() - - label_list = processor.get_labels() - - tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle) - - tpu_cluster_resolver = None - if FLAGS.use_tpu and FLAGS.tpu_name: - tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( - FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.compat.v1.estimator.tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=FLAGS.master, - model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, - tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - - train_examples = None - num_train_steps = None - num_warmup_steps = None - if FLAGS.do_train: - train_examples = processor.get_train_examples(FLAGS.data_dir) - num_train_steps = int( - len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) - num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) - - model_fn = model_fn_builder( - num_labels=len(label_list), - learning_rate=FLAGS.learning_rate, - num_train_steps=num_train_steps, - num_warmup_steps=num_warmup_steps, - use_tpu=FLAGS.use_tpu, - bert_hub_module_handle=FLAGS.bert_hub_module_handle) - - # If TPU is not available, this will fall back to normal Estimator on CPU - # or GPU. - estimator = tf.compat.v1.estimator.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, - model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size, - predict_batch_size=FLAGS.predict_batch_size) - - if FLAGS.do_train: - train_features = run_classifier.convert_examples_to_features( - train_examples, label_list, FLAGS.max_seq_length, tokenizer) - tf.compat.v1.logging.info("***** Running training *****") - tf.compat.v1.logging.info(" Num examples = %d", len(train_examples)) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) - tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) - train_input_fn = run_classifier.input_fn_builder( - features=train_features, - seq_length=FLAGS.max_seq_length, - is_training=True, - drop_remainder=True) - estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) - - if FLAGS.do_eval: - eval_examples = processor.get_dev_examples(FLAGS.data_dir) - eval_features = run_classifier.convert_examples_to_features( - eval_examples, label_list, FLAGS.max_seq_length, tokenizer) - - tf.compat.v1.logging.info("***** Running evaluation *****") - tf.compat.v1.logging.info(" Num examples = %d", len(eval_examples)) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size) - - # This tells the estimator to run through the entire set. - eval_steps = None - # However, if running eval on the TPU, you will need to specify the - # number of steps. - if FLAGS.use_tpu: - # Eval will be slightly WRONG on the TPU because it will truncate - # the last batch. - eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) - - eval_drop_remainder = True if FLAGS.use_tpu else False - eval_input_fn = run_classifier.input_fn_builder( - features=eval_features, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=eval_drop_remainder) - - result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) - - output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") - with tf.io.gfile.GFile(output_eval_file, "w") as writer: - tf.compat.v1.logging.info("***** Eval results *****") - for key in sorted(result.keys()): - tf.compat.v1.logging.info(" %s = %s", key, str(result[key])) - writer.write("%s = %s\n" % (key, str(result[key]))) - - if FLAGS.do_predict: - predict_examples = processor.get_test_examples(FLAGS.data_dir) - if FLAGS.use_tpu: - # Discard batch remainder if running on TPU - n = len(predict_examples) - predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)] - - predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") - run_classifier.file_based_convert_examples_to_features( - predict_examples, label_list, FLAGS.max_seq_length, tokenizer, - predict_file) - - tf.compat.v1.logging.info("***** Running prediction*****") - tf.compat.v1.logging.info(" Num examples = %d", len(predict_examples)) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size) - - predict_input_fn = run_classifier.file_based_input_fn_builder( - input_file=predict_file, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=FLAGS.use_tpu) - - result = estimator.predict(input_fn=predict_input_fn) - - output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") - with tf.io.gfile.GFile(output_predict_file, "w") as writer: - tf.compat.v1.logging.info("***** Predict results *****") - for prediction in result: - probabilities = prediction["probabilities"] - output_line = "\t".join( - str(class_probability) - for class_probability in probabilities) + "\n" - writer.write(output_line) + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + processors = { + "cola": run_classifier.ColaProcessor, + "mnli": run_classifier.MnliProcessor, + "mrpc": run_classifier.MrpcProcessor, + } + + if not FLAGS.do_train and not FLAGS.do_eval: + raise ValueError("At least one of `do_train` or `do_eval` must be True.") + + tf.io.gfile.makedirs(FLAGS.output_dir) + + task_name = FLAGS.task_name.lower() + + if task_name not in processors: + raise ValueError("Task not found: %s" % (task_name)) + + processor = processors[task_name]() + + label_list = processor.get_labels() + + tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle) + + tpu_cluster_resolver = None + if FLAGS.use_tpu and FLAGS.tpu_name: + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project + ) + + is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 + run_config = tf.compat.v1.estimator.tpu.RunConfig( + cluster=tpu_cluster_resolver, + master=FLAGS.master, + model_dir=FLAGS.output_dir, + save_checkpoints_steps=FLAGS.save_checkpoints_steps, + tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_tpu_cores, + per_host_input_for_training=is_per_host, + ), + ) + + train_examples = None + num_train_steps = None + num_warmup_steps = None + if FLAGS.do_train: + train_examples = processor.get_train_examples(FLAGS.data_dir) + num_train_steps = int( + len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs + ) + num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) + + model_fn = model_fn_builder( + num_labels=len(label_list), + learning_rate=FLAGS.learning_rate, + num_train_steps=num_train_steps, + num_warmup_steps=num_warmup_steps, + use_tpu=FLAGS.use_tpu, + bert_hub_module_handle=FLAGS.bert_hub_module_handle, + ) + + # If TPU is not available, this will fall back to normal Estimator on CPU + # or GPU. + estimator = tf.compat.v1.estimator.tpu.TPUEstimator( + use_tpu=FLAGS.use_tpu, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS.train_batch_size, + eval_batch_size=FLAGS.eval_batch_size, + predict_batch_size=FLAGS.predict_batch_size, + ) + + if FLAGS.do_train: + train_features = run_classifier.convert_examples_to_features( + train_examples, label_list, FLAGS.max_seq_length, tokenizer + ) + tf.compat.v1.logging.info("***** Running training *****") + tf.compat.v1.logging.info(" Num examples = %d", len(train_examples)) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) + tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) + train_input_fn = run_classifier.input_fn_builder( + features=train_features, + seq_length=FLAGS.max_seq_length, + is_training=True, + drop_remainder=True, + ) + estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) + + if FLAGS.do_eval: + eval_examples = processor.get_dev_examples(FLAGS.data_dir) + eval_features = run_classifier.convert_examples_to_features( + eval_examples, label_list, FLAGS.max_seq_length, tokenizer + ) + + tf.compat.v1.logging.info("***** Running evaluation *****") + tf.compat.v1.logging.info(" Num examples = %d", len(eval_examples)) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size) + + # This tells the estimator to run through the entire set. + eval_steps = None + # However, if running eval on the TPU, you will need to specify the + # number of steps. + if FLAGS.use_tpu: + # Eval will be slightly WRONG on the TPU because it will truncate + # the last batch. + eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) + + eval_drop_remainder = True if FLAGS.use_tpu else False + eval_input_fn = run_classifier.input_fn_builder( + features=eval_features, + seq_length=FLAGS.max_seq_length, + is_training=False, + drop_remainder=eval_drop_remainder, + ) + + result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) + + output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") + with tf.io.gfile.GFile(output_eval_file, "w") as writer: + tf.compat.v1.logging.info("***** Eval results *****") + for key in sorted(result.keys()): + tf.compat.v1.logging.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) + + if FLAGS.do_predict: + predict_examples = processor.get_test_examples(FLAGS.data_dir) + if FLAGS.use_tpu: + # Discard batch remainder if running on TPU + n = len(predict_examples) + predict_examples = predict_examples[: (n - n % FLAGS.predict_batch_size)] + + predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") + run_classifier.file_based_convert_examples_to_features( + predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file + ) + + tf.compat.v1.logging.info("***** Running prediction*****") + tf.compat.v1.logging.info(" Num examples = %d", len(predict_examples)) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size) + + predict_input_fn = run_classifier.file_based_input_fn_builder( + input_file=predict_file, + seq_length=FLAGS.max_seq_length, + is_training=False, + drop_remainder=FLAGS.use_tpu, + ) + + result = estimator.predict(input_fn=predict_input_fn) + + output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") + with tf.io.gfile.GFile(output_predict_file, "w") as writer: + tf.compat.v1.logging.info("***** Predict results *****") + for prediction in result: + probabilities = prediction["probabilities"] + output_line = ( + "\t".join( + str(class_probability) for class_probability in probabilities + ) + + "\n" + ) + writer.write(output_line) if __name__ == "__main__": - flags.mark_flag_as_required("data_dir") - flags.mark_flag_as_required("task_name") - flags.mark_flag_as_required("bert_hub_module_handle") - flags.mark_flag_as_required("output_dir") - tf.compat.v1.app.run() + flags.mark_flag_as_required("data_dir") + flags.mark_flag_as_required("task_name") + flags.mark_flag_as_required("bert_hub_module_handle") + flags.mark_flag_as_required("output_dir") + tf.compat.v1.app.run() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_pretraining.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_pretraining.py index 357cffb8..23fc8c7c 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_pretraining.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_pretraining.py @@ -29,33 +29,43 @@ ## Required parameters flags.DEFINE_string( - "bert_config_file", None, + "bert_config_file", + None, "The config json file corresponding to the pre-trained BERT model. " - "This specifies the model architecture.") + "This specifies the model architecture.", +) flags.DEFINE_string( - "input_file", None, - "Input TF example files (can be a glob or comma separated).") + "input_file", None, "Input TF example files (can be a glob or comma separated)." +) flags.DEFINE_string( - "output_dir", None, - "The output directory where the model checkpoints will be written.") + "output_dir", + None, + "The output directory where the model checkpoints will be written.", +) ## Other parameters flags.DEFINE_string( - "init_checkpoint", None, - "Initial checkpoint (usually from a pre-trained BERT model).") + "init_checkpoint", + None, + "Initial checkpoint (usually from a pre-trained BERT model).", +) flags.DEFINE_integer( - "max_seq_length", 128, + "max_seq_length", + 128, "The maximum total input sequence length after WordPiece tokenization. " "Sequences longer than this will be truncated, and sequences shorter " - "than this will be padded. Must match data generation.") + "than this will be padded. Must match data generation.", +) flags.DEFINE_integer( - "max_predictions_per_seq", 20, + "max_predictions_per_seq", + 20, "Maximum number of masked LM predictions per sequence. " - "Must match data generation.") + "Must match data generation.", +) flags.DEFINE_bool("do_train", False, "Whether to run training.") @@ -71,423 +81,491 @@ flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") -flags.DEFINE_integer("save_checkpoints_steps", 1000, - "How often to save the model checkpoint.") +flags.DEFINE_integer( + "save_checkpoints_steps", 1000, "How often to save the model checkpoint." +) -flags.DEFINE_integer("iterations_per_loop", 1000, - "How many steps to make in each estimator call.") +flags.DEFINE_integer( + "iterations_per_loop", 1000, "How many steps to make in each estimator call." +) flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") tf.flags.DEFINE_string( - "tpu_name", None, + "tpu_name", + None, "The Cloud TPU to use for training. This should be either the name " "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " - "url.") + "url.", +) tf.flags.DEFINE_string( - "tpu_zone", None, + "tpu_zone", + None, "[Optional] GCE zone where the Cloud TPU is located in. If not " "specified, we will attempt to automatically detect the GCE project from " - "metadata.") + "metadata.", +) tf.flags.DEFINE_string( - "gcp_project", None, + "gcp_project", + None, "[Optional] Project name for the Cloud TPU-enabled project. If not " "specified, we will attempt to automatically detect the GCE project from " - "metadata.") + "metadata.", +) tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") flags.DEFINE_integer( - "num_tpu_cores", 8, - "Only used if `use_tpu` is True. Total number of TPU cores to use.") - - -def model_fn_builder(bert_config, init_checkpoint, learning_rate, - num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings): - """Returns `model_fn` closure for TPUEstimator.""" - - def model_fn(features, labels, mode, params): # pylint: disable=unused-argument - """The `model_fn` for TPUEstimator.""" - - tf.compat.v1.logging.info("*** Features ***") - for name in sorted(features.keys()): - tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) - - input_ids = features["input_ids"] - input_mask = features["input_mask"] - segment_ids = features["segment_ids"] - masked_lm_positions = features["masked_lm_positions"] - masked_lm_ids = features["masked_lm_ids"] - masked_lm_weights = features["masked_lm_weights"] - next_sentence_labels = features["next_sentence_labels"] - - is_training = (mode == tf.estimator.ModeKeys.TRAIN) - - model = modeling.BertModel( - config=bert_config, - is_training=is_training, - input_ids=input_ids, - input_mask=input_mask, - token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) - - (masked_lm_loss, - masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( - bert_config, model.get_sequence_output(), model.get_embedding_table(), - masked_lm_positions, masked_lm_ids, masked_lm_weights) - - (next_sentence_loss, next_sentence_example_loss, - next_sentence_log_probs) = get_next_sentence_output( - bert_config, model.get_pooled_output(), next_sentence_labels) - - total_loss = masked_lm_loss + next_sentence_loss - - tvars = tf.compat.v1.trainable_variables() - - initialized_variable_names = {} - scaffold_fn = None - if init_checkpoint: - (assignment_map, initialized_variable_names - ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) - if use_tpu: - - def tpu_scaffold(): - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - return tf.compat.v1.train.Scaffold() - - scaffold_fn = tpu_scaffold - else: - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - - tf.compat.v1.logging.info("**** Trainable Variables ****") - for var in tvars: - init_string = "" - if var.name in initialized_variable_names: - init_string = ", *INIT_FROM_CKPT*" - tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, - init_string) - - output_spec = None - if mode == tf.estimator.ModeKeys.TRAIN: - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) - - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - train_op=train_op, - scaffold_fn=scaffold_fn) - elif mode == tf.estimator.ModeKeys.EVAL: - - def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, - masked_lm_weights, next_sentence_example_loss, - next_sentence_log_probs, next_sentence_labels): - """Computes the loss and accuracy of the model.""" - masked_lm_log_probs = tf.reshape(masked_lm_log_probs, - [-1, masked_lm_log_probs.shape[-1]]) - masked_lm_predictions = tf.argmax( - input=masked_lm_log_probs, axis=-1, output_type=tf.int32) - masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) - masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) - masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) - masked_lm_accuracy = tf.compat.v1.metrics.accuracy( - labels=masked_lm_ids, - predictions=masked_lm_predictions, - weights=masked_lm_weights) - masked_lm_mean_loss = tf.compat.v1.metrics.mean( - values=masked_lm_example_loss, weights=masked_lm_weights) - - next_sentence_log_probs = tf.reshape( - next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) - next_sentence_predictions = tf.argmax( - input=next_sentence_log_probs, axis=-1, output_type=tf.int32) - next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) - next_sentence_accuracy = tf.compat.v1.metrics.accuracy( - labels=next_sentence_labels, predictions=next_sentence_predictions) - next_sentence_mean_loss = tf.compat.v1.metrics.mean( - values=next_sentence_example_loss) - - return { - "masked_lm_accuracy": masked_lm_accuracy, - "masked_lm_loss": masked_lm_mean_loss, - "next_sentence_accuracy": next_sentence_accuracy, - "next_sentence_loss": next_sentence_mean_loss, - } + "num_tpu_cores", + 8, + "Only used if `use_tpu` is True. Total number of TPU cores to use.", +) + + +def model_fn_builder( + bert_config, + init_checkpoint, + learning_rate, + num_train_steps, + num_warmup_steps, + use_tpu, + use_one_hot_embeddings, +): + """Returns `model_fn` closure for TPUEstimator.""" + + def model_fn(features, labels, mode, params): # pylint: disable=unused-argument + """The `model_fn` for TPUEstimator.""" + + tf.compat.v1.logging.info("*** Features ***") + for name in sorted(features.keys()): + tf.compat.v1.logging.info( + " name = %s, shape = %s" % (name, features[name].shape) + ) + + input_ids = features["input_ids"] + input_mask = features["input_mask"] + segment_ids = features["segment_ids"] + masked_lm_positions = features["masked_lm_positions"] + masked_lm_ids = features["masked_lm_ids"] + masked_lm_weights = features["masked_lm_weights"] + next_sentence_labels = features["next_sentence_labels"] + + is_training = mode == tf.estimator.ModeKeys.TRAIN + + model = modeling.BertModel( + config=bert_config, + is_training=is_training, + input_ids=input_ids, + input_mask=input_mask, + token_type_ids=segment_ids, + use_one_hot_embeddings=use_one_hot_embeddings, + ) + + masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs = ( + get_masked_lm_output( + bert_config, + model.get_sequence_output(), + model.get_embedding_table(), + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + ) + ) + + next_sentence_loss, next_sentence_example_loss, next_sentence_log_probs = ( + get_next_sentence_output( + bert_config, model.get_pooled_output(), next_sentence_labels + ) + ) + + total_loss = masked_lm_loss + next_sentence_loss + + tvars = tf.compat.v1.trainable_variables() + + initialized_variable_names = {} + scaffold_fn = None + if init_checkpoint: + assignment_map, initialized_variable_names = ( + modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) + ) + if use_tpu: + + def tpu_scaffold(): + tf.compat.v1.train.init_from_checkpoint( + init_checkpoint, assignment_map + ) + return tf.compat.v1.train.Scaffold() + + scaffold_fn = tpu_scaffold + else: + tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) + + tf.compat.v1.logging.info("**** Trainable Variables ****") + for var in tvars: + init_string = "" + if var.name in initialized_variable_names: + init_string = ", *INIT_FROM_CKPT*" + tf.compat.v1.logging.info( + " name = %s, shape = %s%s", var.name, var.shape, init_string + ) + + output_spec = None + if mode == tf.estimator.ModeKeys.TRAIN: + train_op = optimization.create_optimizer( + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu + ) + + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn + ) + elif mode == tf.estimator.ModeKeys.EVAL: + + def metric_fn( + masked_lm_example_loss, + masked_lm_log_probs, + masked_lm_ids, + masked_lm_weights, + next_sentence_example_loss, + next_sentence_log_probs, + next_sentence_labels, + ): + """Computes the loss and accuracy of the model.""" + masked_lm_log_probs = tf.reshape( + masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]] + ) + masked_lm_predictions = tf.argmax( + input=masked_lm_log_probs, axis=-1, output_type=tf.int32 + ) + masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) + masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) + masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) + masked_lm_accuracy = tf.compat.v1.metrics.accuracy( + labels=masked_lm_ids, + predictions=masked_lm_predictions, + weights=masked_lm_weights, + ) + masked_lm_mean_loss = tf.compat.v1.metrics.mean( + values=masked_lm_example_loss, weights=masked_lm_weights + ) + + next_sentence_log_probs = tf.reshape( + next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]] + ) + next_sentence_predictions = tf.argmax( + input=next_sentence_log_probs, axis=-1, output_type=tf.int32 + ) + next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) + next_sentence_accuracy = tf.compat.v1.metrics.accuracy( + labels=next_sentence_labels, predictions=next_sentence_predictions + ) + next_sentence_mean_loss = tf.compat.v1.metrics.mean( + values=next_sentence_example_loss + ) + + return { + "masked_lm_accuracy": masked_lm_accuracy, + "masked_lm_loss": masked_lm_mean_loss, + "next_sentence_accuracy": next_sentence_accuracy, + "next_sentence_loss": next_sentence_mean_loss, + } + + eval_metrics = ( + metric_fn, + [ + masked_lm_example_loss, + masked_lm_log_probs, + masked_lm_ids, + masked_lm_weights, + next_sentence_example_loss, + next_sentence_log_probs, + next_sentence_labels, + ], + ) + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, + loss=total_loss, + eval_metrics=eval_metrics, + scaffold_fn=scaffold_fn, + ) + else: + raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) + + return output_spec + + return model_fn + + +def get_masked_lm_output( + bert_config, input_tensor, output_weights, positions, label_ids, label_weights +): + """Get loss and log probs for the masked LM.""" + input_tensor = gather_indexes(input_tensor, positions) + + with tf.compat.v1.variable_scope("cls/predictions"): + # We apply one more non-linear transformation before the output layer. + # This matrix is not used after pre-training. + with tf.compat.v1.variable_scope("transform"): + input_tensor = tf.compat.v1.layers.dense( + input_tensor, + units=bert_config.hidden_size, + activation=modeling.get_activation(bert_config.hidden_act), + kernel_initializer=modeling.create_initializer( + bert_config.initializer_range + ), + ) + input_tensor = modeling.layer_norm(input_tensor) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + output_bias = tf.compat.v1.get_variable( + "output_bias", + shape=[bert_config.vocab_size], + initializer=tf.compat.v1.zeros_initializer(), + ) + logits = tf.matmul(input_tensor, output_weights, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + log_probs = tf.nn.log_softmax(logits, axis=-1) + + label_ids = tf.reshape(label_ids, [-1]) + label_weights = tf.reshape(label_weights, [-1]) + + one_hot_labels = tf.one_hot( + label_ids, depth=bert_config.vocab_size, dtype=tf.float32 + ) + + # The `positions` tensor might be zero-padded (if the sequence is too + # short to have the maximum number of predictions). The `label_weights` + # tensor has a value of 1.0 for every real prediction and 0.0 for the + # padding predictions. + per_example_loss = -tf.reduce_sum( + input_tensor=log_probs * one_hot_labels, axis=[-1] + ) + numerator = tf.reduce_sum(input_tensor=label_weights * per_example_loss) + denominator = tf.reduce_sum(input_tensor=label_weights) + 1e-5 + loss = numerator / denominator - eval_metrics = (metric_fn, [ - masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, - masked_lm_weights, next_sentence_example_loss, - next_sentence_log_probs, next_sentence_labels - ]) - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - eval_metrics=eval_metrics, - scaffold_fn=scaffold_fn) - else: - raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) - - return output_spec - - return model_fn - - -def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, - label_ids, label_weights): - """Get loss and log probs for the masked LM.""" - input_tensor = gather_indexes(input_tensor, positions) - - with tf.compat.v1.variable_scope("cls/predictions"): - # We apply one more non-linear transformation before the output layer. - # This matrix is not used after pre-training. - with tf.compat.v1.variable_scope("transform"): - input_tensor = tf.compat.v1.layers.dense( - input_tensor, - units=bert_config.hidden_size, - activation=modeling.get_activation(bert_config.hidden_act), - kernel_initializer=modeling.create_initializer( - bert_config.initializer_range)) - input_tensor = modeling.layer_norm(input_tensor) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - output_bias = tf.compat.v1.get_variable( - "output_bias", - shape=[bert_config.vocab_size], - initializer=tf.compat.v1.zeros_initializer()) - logits = tf.matmul(input_tensor, output_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - log_probs = tf.nn.log_softmax(logits, axis=-1) - - label_ids = tf.reshape(label_ids, [-1]) - label_weights = tf.reshape(label_weights, [-1]) - - one_hot_labels = tf.one_hot( - label_ids, depth=bert_config.vocab_size, dtype=tf.float32) - - # The `positions` tensor might be zero-padded (if the sequence is too - # short to have the maximum number of predictions). The `label_weights` - # tensor has a value of 1.0 for every real prediction and 0.0 for the - # padding predictions. - per_example_loss = -tf.reduce_sum(input_tensor=log_probs * one_hot_labels, axis=[-1]) - numerator = tf.reduce_sum(input_tensor=label_weights * per_example_loss) - denominator = tf.reduce_sum(input_tensor=label_weights) + 1e-5 - loss = numerator / denominator - - return (loss, per_example_loss, log_probs) + return (loss, per_example_loss, log_probs) def get_next_sentence_output(bert_config, input_tensor, labels): - """Get loss and log probs for the next sentence prediction.""" - - # Simple binary classification. Note that 0 is "next sentence" and 1 is - # "random sentence". This weight matrix is not used after pre-training. - with tf.compat.v1.variable_scope("cls/seq_relationship"): - output_weights = tf.compat.v1.get_variable( - "output_weights", - shape=[2, bert_config.hidden_size], - initializer=modeling.create_initializer(bert_config.initializer_range)) - output_bias = tf.compat.v1.get_variable( - "output_bias", shape=[2], initializer=tf.compat.v1.zeros_initializer()) - - logits = tf.matmul(input_tensor, output_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - log_probs = tf.nn.log_softmax(logits, axis=-1) - labels = tf.reshape(labels, [-1]) - one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) - per_example_loss = -tf.reduce_sum(input_tensor=one_hot_labels * log_probs, axis=-1) - loss = tf.reduce_mean(input_tensor=per_example_loss) - return (loss, per_example_loss, log_probs) + """Get loss and log probs for the next sentence prediction.""" + + # Simple binary classification. Note that 0 is "next sentence" and 1 is + # "random sentence". This weight matrix is not used after pre-training. + with tf.compat.v1.variable_scope("cls/seq_relationship"): + output_weights = tf.compat.v1.get_variable( + "output_weights", + shape=[2, bert_config.hidden_size], + initializer=modeling.create_initializer(bert_config.initializer_range), + ) + output_bias = tf.compat.v1.get_variable( + "output_bias", shape=[2], initializer=tf.compat.v1.zeros_initializer() + ) + + logits = tf.matmul(input_tensor, output_weights, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + log_probs = tf.nn.log_softmax(logits, axis=-1) + labels = tf.reshape(labels, [-1]) + one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) + per_example_loss = -tf.reduce_sum( + input_tensor=one_hot_labels * log_probs, axis=-1 + ) + loss = tf.reduce_mean(input_tensor=per_example_loss) + return (loss, per_example_loss, log_probs) def gather_indexes(sequence_tensor, positions): - """Gathers the vectors at the specific positions over a minibatch.""" - sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) - batch_size = sequence_shape[0] - seq_length = sequence_shape[1] - width = sequence_shape[2] - - flat_offsets = tf.reshape( - tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) - flat_positions = tf.reshape(positions + flat_offsets, [-1]) - flat_sequence_tensor = tf.reshape(sequence_tensor, - [batch_size * seq_length, width]) - output_tensor = tf.gather(flat_sequence_tensor, flat_positions) - return output_tensor - - -def input_fn_builder(input_files, - max_seq_length, - max_predictions_per_seq, - is_training, - num_cpu_threads=4): - """Creates an `input_fn` closure to be passed to TPUEstimator.""" - - def input_fn(params): - """The actual input function.""" - batch_size = params["batch_size"] - - name_to_features = { - "input_ids": - tf.io.FixedLenFeature([max_seq_length], tf.int64), - "input_mask": - tf.io.FixedLenFeature([max_seq_length], tf.int64), - "segment_ids": - tf.io.FixedLenFeature([max_seq_length], tf.int64), - "masked_lm_positions": - tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64), - "masked_lm_ids": - tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64), - "masked_lm_weights": - tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32), - "next_sentence_labels": - tf.io.FixedLenFeature([1], tf.int64), - } - - # For training, we want a lot of parallel reading and shuffling. - # For eval, we want no shuffling and parallel reading doesn't matter. - if is_training: - d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) - d = d.repeat() - d = d.shuffle(buffer_size=len(input_files)) - - # `cycle_length` is the number of parallel files that get read. - cycle_length = min(num_cpu_threads, len(input_files)) - - # `sloppy` mode means that the interleaving is not exact. This adds - # even more randomness to the training pipeline. - d = d.apply( - tf.data.experimental.parallel_interleave( - tf.data.TFRecordDataset, - sloppy=is_training, - cycle_length=cycle_length)) - d = d.shuffle(buffer_size=100) - else: - d = tf.data.TFRecordDataset(input_files) - # Since we evaluate for a fixed number of steps we don't want to encounter - # out-of-range exceptions. - d = d.repeat() - - # We must `drop_remainder` on training because the TPU requires fixed - # size dimensions. For eval, we assume we are evaluating on the CPU or GPU - # and we *don't* want to drop the remainder, otherwise we wont cover - # every sample. - d = d.apply( - tf.data.experimental.map_and_batch( - lambda record: _decode_record(record, name_to_features), - batch_size=batch_size, - num_parallel_batches=num_cpu_threads, - drop_remainder=True)) - return d - - return input_fn + """Gathers the vectors at the specific positions over a minibatch.""" + sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) + batch_size = sequence_shape[0] + seq_length = sequence_shape[1] + width = sequence_shape[2] + + flat_offsets = tf.reshape( + tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1] + ) + flat_positions = tf.reshape(positions + flat_offsets, [-1]) + flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width]) + output_tensor = tf.gather(flat_sequence_tensor, flat_positions) + return output_tensor + + +def input_fn_builder( + input_files, max_seq_length, max_predictions_per_seq, is_training, num_cpu_threads=4 +): + """Creates an `input_fn` closure to be passed to TPUEstimator.""" + + def input_fn(params): + """The actual input function.""" + batch_size = params["batch_size"] + + name_to_features = { + "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64), + "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64), + "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64), + "masked_lm_positions": tf.io.FixedLenFeature( + [max_predictions_per_seq], tf.int64 + ), + "masked_lm_ids": tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64), + "masked_lm_weights": tf.io.FixedLenFeature( + [max_predictions_per_seq], tf.float32 + ), + "next_sentence_labels": tf.io.FixedLenFeature([1], tf.int64), + } + + # For training, we want a lot of parallel reading and shuffling. + # For eval, we want no shuffling and parallel reading doesn't matter. + if is_training: + d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) + d = d.repeat() + d = d.shuffle(buffer_size=len(input_files)) + + # `cycle_length` is the number of parallel files that get read. + cycle_length = min(num_cpu_threads, len(input_files)) + + # `sloppy` mode means that the interleaving is not exact. This adds + # even more randomness to the training pipeline. + d = d.apply( + tf.data.experimental.parallel_interleave( + tf.data.TFRecordDataset, + sloppy=is_training, + cycle_length=cycle_length, + ) + ) + d = d.shuffle(buffer_size=100) + else: + d = tf.data.TFRecordDataset(input_files) + # Since we evaluate for a fixed number of steps we don't want to encounter + # out-of-range exceptions. + d = d.repeat() + + # We must `drop_remainder` on training because the TPU requires fixed + # size dimensions. For eval, we assume we are evaluating on the CPU or GPU + # and we *don't* want to drop the remainder, otherwise we wont cover + # every sample. + d = d.apply( + tf.data.experimental.map_and_batch( + lambda record: _decode_record(record, name_to_features), + batch_size=batch_size, + num_parallel_batches=num_cpu_threads, + drop_remainder=True, + ) + ) + return d + + return input_fn def _decode_record(record, name_to_features): - """Decodes a record to a TensorFlow example.""" - example = tf.io.parse_single_example(serialized=record, features=name_to_features) + """Decodes a record to a TensorFlow example.""" + example = tf.io.parse_single_example(serialized=record, features=name_to_features) - # tf.Example only supports tf.int64, but the TPU only supports tf.int32. - # So cast all int64 to int32. - for name in list(example.keys()): - t = example[name] - if t.dtype == tf.int64: - t = tf.cast(t, dtype=tf.int32) - example[name] = t + # tf.Example only supports tf.int64, but the TPU only supports tf.int32. + # So cast all int64 to int32. + for name in list(example.keys()): + t = example[name] + if t.dtype == tf.int64: + t = tf.cast(t, dtype=tf.int32) + example[name] = t - return example + return example def main(_): - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) - - if not FLAGS.do_train and not FLAGS.do_eval: - raise ValueError("At least one of `do_train` or `do_eval` must be True.") - - bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) - - tf.io.gfile.makedirs(FLAGS.output_dir) - - input_files = [] - for input_pattern in FLAGS.input_file.split(","): - input_files.extend(tf.io.gfile.glob(input_pattern)) - - tf.compat.v1.logging.info("*** Input Files ***") - for input_file in input_files: - tf.compat.v1.logging.info(" %s" % input_file) - - tpu_cluster_resolver = None - if FLAGS.use_tpu and FLAGS.tpu_name: - tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( - FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.compat.v1.estimator.tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=FLAGS.master, - model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, - tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - - model_fn = model_fn_builder( - bert_config=bert_config, - init_checkpoint=FLAGS.init_checkpoint, - learning_rate=FLAGS.learning_rate, - num_train_steps=FLAGS.num_train_steps, - num_warmup_steps=FLAGS.num_warmup_steps, - use_tpu=FLAGS.use_tpu, - use_one_hot_embeddings=FLAGS.use_tpu) - - # If TPU is not available, this will fall back to normal Estimator on CPU - # or GPU. - estimator = tf.compat.v1.estimator.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, - model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size) - - if FLAGS.do_train: - tf.compat.v1.logging.info("***** Running training *****") - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) - train_input_fn = input_fn_builder( - input_files=input_files, - max_seq_length=FLAGS.max_seq_length, - max_predictions_per_seq=FLAGS.max_predictions_per_seq, - is_training=True) - estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) - - if FLAGS.do_eval: - tf.compat.v1.logging.info("***** Running evaluation *****") - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size) - - eval_input_fn = input_fn_builder( - input_files=input_files, - max_seq_length=FLAGS.max_seq_length, - max_predictions_per_seq=FLAGS.max_predictions_per_seq, - is_training=False) - - result = estimator.evaluate( - input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) - - output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") - with tf.io.gfile.GFile(output_eval_file, "w") as writer: - tf.compat.v1.logging.info("***** Eval results *****") - for key in sorted(result.keys()): - tf.compat.v1.logging.info(" %s = %s", key, str(result[key])) - writer.write("%s = %s\n" % (key, str(result[key]))) + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + if not FLAGS.do_train and not FLAGS.do_eval: + raise ValueError("At least one of `do_train` or `do_eval` must be True.") + + bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) + + tf.io.gfile.makedirs(FLAGS.output_dir) + + input_files = [] + for input_pattern in FLAGS.input_file.split(","): + input_files.extend(tf.io.gfile.glob(input_pattern)) + + tf.compat.v1.logging.info("*** Input Files ***") + for input_file in input_files: + tf.compat.v1.logging.info(" %s" % input_file) + + tpu_cluster_resolver = None + if FLAGS.use_tpu and FLAGS.tpu_name: + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project + ) + + is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 + run_config = tf.compat.v1.estimator.tpu.RunConfig( + cluster=tpu_cluster_resolver, + master=FLAGS.master, + model_dir=FLAGS.output_dir, + save_checkpoints_steps=FLAGS.save_checkpoints_steps, + tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_tpu_cores, + per_host_input_for_training=is_per_host, + ), + ) + + model_fn = model_fn_builder( + bert_config=bert_config, + init_checkpoint=FLAGS.init_checkpoint, + learning_rate=FLAGS.learning_rate, + num_train_steps=FLAGS.num_train_steps, + num_warmup_steps=FLAGS.num_warmup_steps, + use_tpu=FLAGS.use_tpu, + use_one_hot_embeddings=FLAGS.use_tpu, + ) + + # If TPU is not available, this will fall back to normal Estimator on CPU + # or GPU. + estimator = tf.compat.v1.estimator.tpu.TPUEstimator( + use_tpu=FLAGS.use_tpu, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS.train_batch_size, + eval_batch_size=FLAGS.eval_batch_size, + ) + + if FLAGS.do_train: + tf.compat.v1.logging.info("***** Running training *****") + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) + train_input_fn = input_fn_builder( + input_files=input_files, + max_seq_length=FLAGS.max_seq_length, + max_predictions_per_seq=FLAGS.max_predictions_per_seq, + is_training=True, + ) + estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) + + if FLAGS.do_eval: + tf.compat.v1.logging.info("***** Running evaluation *****") + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size) + + eval_input_fn = input_fn_builder( + input_files=input_files, + max_seq_length=FLAGS.max_seq_length, + max_predictions_per_seq=FLAGS.max_predictions_per_seq, + is_training=False, + ) + + result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) + + output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") + with tf.io.gfile.GFile(output_eval_file, "w") as writer: + tf.compat.v1.logging.info("***** Eval results *****") + for key in sorted(result.keys()): + tf.compat.v1.logging.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) if __name__ == "__main__": - flags.mark_flag_as_required("input_file") - flags.mark_flag_as_required("bert_config_file") - flags.mark_flag_as_required("output_dir") - tf.compat.v1.app.run() + flags.mark_flag_as_required("input_file") + flags.mark_flag_as_required("bert_config_file") + flags.mark_flag_as_required("output_dir") + tf.compat.v1.app.run() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_squad.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_squad.py index a325d364..96e6d7af 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_squad.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/run_squad.py @@ -31,55 +31,73 @@ from absl import flags import sys -#flags = tf.flags +# flags = tf.flags FLAGS = flags.FLAGS ## Required parameters flags.DEFINE_string( - "bert_config_file", None, + "bert_config_file", + None, "The config json file corresponding to the pre-trained BERT model. " - "This specifies the model architecture.") + "This specifies the model architecture.", +) -flags.DEFINE_string("vocab_file", None, - "The vocabulary file that the BERT model was trained on.") +flags.DEFINE_string( + "vocab_file", None, "The vocabulary file that the BERT model was trained on." +) flags.DEFINE_string( - "output_dir", None, - "The output directory where the model checkpoints will be written.") + "output_dir", + None, + "The output directory where the model checkpoints will be written.", +) ## Other parameters -flags.DEFINE_string("train_file", None, - "SQuAD json for training. E.g., train-v1.1.json") +flags.DEFINE_string( + "train_file", None, "SQuAD json for training. E.g., train-v1.1.json" +) flags.DEFINE_string( - "predict_file", None, - "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") + "predict_file", + None, + "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json", +) flags.DEFINE_string( - "init_checkpoint", None, - "Initial checkpoint (usually from a pre-trained BERT model).") + "init_checkpoint", + None, + "Initial checkpoint (usually from a pre-trained BERT model).", +) flags.DEFINE_bool( - "do_lower_case", True, + "do_lower_case", + True, "Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.") + "models and False for cased models.", +) flags.DEFINE_integer( - "max_seq_length", 384, + "max_seq_length", + 384, "The maximum total input sequence length after WordPiece tokenization. " "Sequences longer than this will be truncated, and sequences shorter " - "than this will be padded.") + "than this will be padded.", +) flags.DEFINE_integer( - "doc_stride", 128, + "doc_stride", + 128, "When splitting up a long document into chunks, how much stride to " - "take between chunks.") + "take between chunks.", +) flags.DEFINE_integer( - "max_query_length", 64, + "max_query_length", + 64, "The maximum number of tokens for the question. Questions longer than " - "this will be truncated to this length.") + "this will be truncated to this length.", +) flags.DEFINE_bool("do_train", False, "Whether to run training.") @@ -87,1200 +105,1341 @@ flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") -flags.DEFINE_integer("predict_batch_size", 8, - "Total batch size for predictions.") +flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predictions.") flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") -flags.DEFINE_float("num_train_epochs", 3.0, - "Total number of training epochs to perform.") +flags.DEFINE_float( + "num_train_epochs", 3.0, "Total number of training epochs to perform." +) flags.DEFINE_float( - "warmup_proportion", 0.1, + "warmup_proportion", + 0.1, "Proportion of training to perform linear learning rate warmup for. " - "E.g., 0.1 = 10% of training.") + "E.g., 0.1 = 10% of training.", +) -flags.DEFINE_integer("save_checkpoints_steps", 1000, - "How often to save the model checkpoint.") +flags.DEFINE_integer( + "save_checkpoints_steps", 1000, "How often to save the model checkpoint." +) -flags.DEFINE_integer("iterations_per_loop", 1000, - "How many steps to make in each estimator call.") +flags.DEFINE_integer( + "iterations_per_loop", 1000, "How many steps to make in each estimator call." +) flags.DEFINE_integer( - "n_best_size", 20, + "n_best_size", + 20, "The total number of n-best predictions to generate in the " - "nbest_predictions.json output file.") + "nbest_predictions.json output file.", +) flags.DEFINE_integer( - "max_answer_length", 30, + "max_answer_length", + 30, "The maximum length of an answer that can be generated. This is needed " - "because the start and end predictions are not conditioned on one another.") + "because the start and end predictions are not conditioned on one another.", +) flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") flags.DEFINE_string( - "tpu_name", None, + "tpu_name", + None, "The Cloud TPU to use for training. This should be either the name " "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " - "url.") + "url.", +) flags.DEFINE_string( - "tpu_zone", None, + "tpu_zone", + None, "[Optional] GCE zone where the Cloud TPU is located in. If not " "specified, we will attempt to automatically detect the GCE project from " - "metadata.") + "metadata.", +) flags.DEFINE_string( - "gcp_project", None, + "gcp_project", + None, "[Optional] Project name for the Cloud TPU-enabled project. If not " "specified, we will attempt to automatically detect the GCE project from " - "metadata.") + "metadata.", +) flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") flags.DEFINE_integer( - "num_tpu_cores", 8, - "Only used if `use_tpu` is True. Total number of TPU cores to use.") + "num_tpu_cores", + 8, + "Only used if `use_tpu` is True. Total number of TPU cores to use.", +) flags.DEFINE_bool( - "verbose_logging", False, + "verbose_logging", + False, "If true, all of the warnings related to data processing will be printed. " - "A number of warnings are expected for a normal SQuAD evaluation.") + "A number of warnings are expected for a normal SQuAD evaluation.", +) flags.DEFINE_bool( - "version_2_with_negative", False, - "If true, the SQuAD examples contain some that do not have an answer.") + "version_2_with_negative", + False, + "If true, the SQuAD examples contain some that do not have an answer.", +) flags.DEFINE_float( - "null_score_diff_threshold", 0.0, - "If null_score - best_non_null is greater than the threshold predict null.") + "null_score_diff_threshold", + 0.0, + "If null_score - best_non_null is greater than the threshold predict null.", +) class SquadExample(object): - """A single training/test example for simple sequence classification. - - For examples without an answer, the start and end position are -1. - """ - - def __init__(self, - qas_id, - question_text, - doc_tokens, - orig_answer_text=None, - start_position=None, - end_position=None, - is_impossible=False): - self.qas_id = qas_id - self.question_text = question_text - self.doc_tokens = doc_tokens - self.orig_answer_text = orig_answer_text - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible - - def __str__(self): - return self.__repr__() - - def __repr__(self): - s = "" - s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) - s += ", question_text: %s" % ( - tokenization.printable_text(self.question_text)) - s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) - if self.start_position: - s += ", start_position: %d" % (self.start_position) - if self.start_position: - s += ", end_position: %d" % (self.end_position) - if self.start_position: - s += ", is_impossible: %r" % (self.is_impossible) - return s + """A single training/test example for simple sequence classification. + + For examples without an answer, the start and end position are -1. + """ + + def __init__( + self, + qas_id, + question_text, + doc_tokens, + orig_answer_text=None, + start_position=None, + end_position=None, + is_impossible=False, + ): + self.qas_id = qas_id + self.question_text = question_text + self.doc_tokens = doc_tokens + self.orig_answer_text = orig_answer_text + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = "" + s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) + s += ", question_text: %s" % (tokenization.printable_text(self.question_text)) + s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) + if self.start_position: + s += ", start_position: %d" % (self.start_position) + if self.start_position: + s += ", end_position: %d" % (self.end_position) + if self.start_position: + s += ", is_impossible: %r" % (self.is_impossible) + return s class InputFeatures(object): - """A single set of features of data.""" - - def __init__(self, - unique_id, - example_index, - doc_span_index, - tokens, - token_to_orig_map, - token_is_max_context, - input_ids, - input_mask, - segment_ids, - start_position=None, - end_position=None, - is_impossible=None): - self.unique_id = unique_id - self.example_index = example_index - self.doc_span_index = doc_span_index - self.tokens = tokens - self.token_to_orig_map = token_to_orig_map - self.token_is_max_context = token_is_max_context - self.input_ids = input_ids - self.input_mask = input_mask - self.segment_ids = segment_ids - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible + """A single set of features of data.""" + + def __init__( + self, + unique_id, + example_index, + doc_span_index, + tokens, + token_to_orig_map, + token_is_max_context, + input_ids, + input_mask, + segment_ids, + start_position=None, + end_position=None, + is_impossible=None, + ): + self.unique_id = unique_id + self.example_index = example_index + self.doc_span_index = doc_span_index + self.tokens = tokens + self.token_to_orig_map = token_to_orig_map + self.token_is_max_context = token_is_max_context + self.input_ids = input_ids + self.input_mask = input_mask + self.segment_ids = segment_ids + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible def read_squad_examples(input_file, is_training): - """Read a SQuAD json file into a list of SquadExample.""" - with tf.io.gfile.GFile(input_file, "r") as reader: - input_data = json.load(reader)["data"] - - def is_whitespace(c): - if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: - return True - return False - - examples = [] - for entry in input_data: - for paragraph in entry["paragraphs"]: - paragraph_text = paragraph["context"] - doc_tokens = [] - char_to_word_offset = [] - prev_is_whitespace = True - for c in paragraph_text: - if is_whitespace(c): - prev_is_whitespace = True - else: - if prev_is_whitespace: - doc_tokens.append(c) - else: - doc_tokens[-1] += c - prev_is_whitespace = False - char_to_word_offset.append(len(doc_tokens) - 1) - - for qa in paragraph["qas"]: - qas_id = qa["id"] - question_text = qa["question"] - start_position = None - end_position = None - orig_answer_text = None - is_impossible = False - if is_training: - - #if FLAGS.version_2_with_negative: - # is_impossible = qa["is_impossible"] - if (len(qa["answers"]) != 1) and (not is_impossible): - raise ValueError( - "For training, each question should have exactly 1 answer.") - if not is_impossible: - answer = qa["answers"][0] - orig_answer_text = answer["text"] - answer_offset = answer["answer_start"] - answer_length = len(orig_answer_text) - start_position = char_to_word_offset[answer_offset] - end_position = char_to_word_offset[answer_offset + answer_length - - 1] - # Only add answers where the text can be exactly recovered from the - # document. If this CAN'T happen it's likely due to weird Unicode - # stuff so we will just skip the example. - # - # Note that this means for training mode, every example is NOT - # guaranteed to be preserved. - actual_text = " ".join( - doc_tokens[start_position:(end_position + 1)]) - cleaned_answer_text = " ".join( - tokenization.whitespace_tokenize(orig_answer_text)) - if actual_text.find(cleaned_answer_text) == -1: - tf.compat.v1.logging.warning("Could not find answer: '%s' vs. '%s'", - actual_text, cleaned_answer_text) - continue - else: - start_position = -1 - end_position = -1 - orig_answer_text = "" - - example = SquadExample( - qas_id=qas_id, - question_text=question_text, - doc_tokens=doc_tokens, - orig_answer_text=orig_answer_text, - start_position=start_position, - end_position=end_position, - is_impossible=is_impossible) - examples.append(example) - - return examples - - -def convert_examples_to_features(examples, tokenizer, max_seq_length, - doc_stride, max_query_length, is_training, - output_fn): - """Loads a data file into a list of `InputBatch`s.""" - - unique_id = 1000000000 - - for (example_index, example) in enumerate(examples): - query_tokens = tokenizer.tokenize(example.question_text) - - if len(query_tokens) > max_query_length: - query_tokens = query_tokens[0:max_query_length] - - tok_to_orig_index = [] - orig_to_tok_index = [] - all_doc_tokens = [] - for (i, token) in enumerate(example.doc_tokens): - orig_to_tok_index.append(len(all_doc_tokens)) - sub_tokens = tokenizer.tokenize(token) - for sub_token in sub_tokens: - tok_to_orig_index.append(i) - all_doc_tokens.append(sub_token) - - tok_start_position = None - tok_end_position = None - if is_training and example.is_impossible: - tok_start_position = -1 - tok_end_position = -1 - if is_training and not example.is_impossible: - tok_start_position = orig_to_tok_index[example.start_position] - if example.end_position < len(example.doc_tokens) - 1: - tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 - else: - tok_end_position = len(all_doc_tokens) - 1 - (tok_start_position, tok_end_position) = _improve_answer_span( - all_doc_tokens, tok_start_position, tok_end_position, tokenizer, - example.orig_answer_text) - - # The -3 accounts for [CLS], [SEP] and [SEP] - max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 - - # We can have documents that are longer than the maximum sequence length. - # To deal with this we do a sliding window approach, where we take chunks - # of the up to our max length with a stride of `doc_stride`. - _DocSpan = collections.namedtuple( # pylint: disable=invalid-name - "DocSpan", ["start", "length"]) - doc_spans = [] - start_offset = 0 - while start_offset < len(all_doc_tokens): - length = len(all_doc_tokens) - start_offset - if length > max_tokens_for_doc: - length = max_tokens_for_doc - doc_spans.append(_DocSpan(start=start_offset, length=length)) - if start_offset + length == len(all_doc_tokens): - break - start_offset += min(length, doc_stride) - - for (doc_span_index, doc_span) in enumerate(doc_spans): - tokens = [] - token_to_orig_map = {} - token_is_max_context = {} - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in query_tokens: - tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) - - for i in range(doc_span.length): - split_token_index = doc_span.start + i - token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] - - is_max_context = _check_is_max_context(doc_spans, doc_span_index, - split_token_index) - token_is_max_context[len(tokens)] = is_max_context - tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) - - input_ids = tokenizer.convert_tokens_to_ids(tokens) - - # The mask has 1 for real tokens and 0 for padding tokens. Only real - # tokens are attended to. - input_mask = [1] * len(input_ids) - - # Zero-pad up to the sequence length. - while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) - - assert len(input_ids) == max_seq_length - assert len(input_mask) == max_seq_length - assert len(segment_ids) == max_seq_length - - start_position = None - end_position = None - if is_training and not example.is_impossible: - # For training, if our document chunk does not contain an annotation - # we throw it out, since there is nothing to predict. - doc_start = doc_span.start - doc_end = doc_span.start + doc_span.length - 1 - out_of_span = False - if not (tok_start_position >= doc_start and - tok_end_position <= doc_end): - out_of_span = True - if out_of_span: - start_position = 0 - end_position = 0 - else: - doc_offset = len(query_tokens) + 2 - start_position = tok_start_position - doc_start + doc_offset - end_position = tok_end_position - doc_start + doc_offset - - if is_training and example.is_impossible: - start_position = 0 - end_position = 0 - - if example_index < 20: - tf.compat.v1.logging.info("*** Example ***") - tf.compat.v1.logging.info("unique_id: %s" % (unique_id)) - tf.compat.v1.logging.info("example_index: %s" % (example_index)) - tf.compat.v1.logging.info("doc_span_index: %s" % (doc_span_index)) - tf.compat.v1.logging.info("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in tokens])) - tf.compat.v1.logging.info("token_to_orig_map: %s" % " ".join( - ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) - tf.compat.v1.logging.info("token_is_max_context: %s" % " ".join([ - "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) - ])) - tf.compat.v1.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) - tf.compat.v1.logging.info( - "input_mask: %s" % " ".join([str(x) for x in input_mask])) - tf.compat.v1.logging.info( - "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) + """Read a SQuAD json file into a list of SquadExample.""" + with tf.io.gfile.GFile(input_file, "r") as reader: + input_data = json.load(reader)["data"] + + def is_whitespace(c): + if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: + return True + return False + + examples = [] + for entry in input_data: + for paragraph in entry["paragraphs"]: + paragraph_text = paragraph["context"] + doc_tokens = [] + char_to_word_offset = [] + prev_is_whitespace = True + for c in paragraph_text: + if is_whitespace(c): + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False + char_to_word_offset.append(len(doc_tokens) - 1) + + for qa in paragraph["qas"]: + qas_id = qa["id"] + question_text = qa["question"] + start_position = None + end_position = None + orig_answer_text = None + is_impossible = False + if is_training: + + # if FLAGS.version_2_with_negative: + # is_impossible = qa["is_impossible"] + if (len(qa["answers"]) != 1) and (not is_impossible): + raise ValueError( + "For training, each question should have exactly 1 answer." + ) + if not is_impossible: + answer = qa["answers"][0] + orig_answer_text = answer["text"] + answer_offset = answer["answer_start"] + answer_length = len(orig_answer_text) + start_position = char_to_word_offset[answer_offset] + end_position = char_to_word_offset[ + answer_offset + answer_length - 1 + ] + # Only add answers where the text can be exactly recovered from the + # document. If this CAN'T happen it's likely due to weird Unicode + # stuff so we will just skip the example. + # + # Note that this means for training mode, every example is NOT + # guaranteed to be preserved. + actual_text = " ".join( + doc_tokens[start_position : (end_position + 1)] + ) + cleaned_answer_text = " ".join( + tokenization.whitespace_tokenize(orig_answer_text) + ) + if actual_text.find(cleaned_answer_text) == -1: + tf.compat.v1.logging.warning( + "Could not find answer: '%s' vs. '%s'", + actual_text, + cleaned_answer_text, + ) + continue + else: + start_position = -1 + end_position = -1 + orig_answer_text = "" + + example = SquadExample( + qas_id=qas_id, + question_text=question_text, + doc_tokens=doc_tokens, + orig_answer_text=orig_answer_text, + start_position=start_position, + end_position=end_position, + is_impossible=is_impossible, + ) + examples.append(example) + + return examples + + +def convert_examples_to_features( + examples, + tokenizer, + max_seq_length, + doc_stride, + max_query_length, + is_training, + output_fn, +): + """Loads a data file into a list of `InputBatch`s.""" + + unique_id = 1000000000 + + for example_index, example in enumerate(examples): + query_tokens = tokenizer.tokenize(example.question_text) + + if len(query_tokens) > max_query_length: + query_tokens = query_tokens[0:max_query_length] + + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for i, token in enumerate(example.doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + + tok_start_position = None + tok_end_position = None if is_training and example.is_impossible: - tf.compat.v1.logging.info("impossible example") + tok_start_position = -1 + tok_end_position = -1 if is_training and not example.is_impossible: - answer_text = " ".join(tokens[start_position:(end_position + 1)]) - tf.compat.v1.logging.info("start_position: %d" % (start_position)) - tf.compat.v1.logging.info("end_position: %d" % (end_position)) - tf.compat.v1.logging.info( - "answer: %s" % (tokenization.printable_text(answer_text))) - - feature = InputFeatures( - unique_id=unique_id, - example_index=example_index, - doc_span_index=doc_span_index, - tokens=tokens, - token_to_orig_map=token_to_orig_map, - token_is_max_context=token_is_max_context, - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - start_position=start_position, - end_position=end_position, - is_impossible=example.is_impossible) - - # Run callback - output_fn(feature) - - unique_id += 1 - - -def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, - orig_answer_text): - """Returns tokenized answer spans that better match the annotated answer.""" - - # The SQuAD annotations are character based. We first project them to - # whitespace-tokenized words. But then after WordPiece tokenization, we can - # often find a "better match". For example: - # - # Question: What year was John Smith born? - # Context: The leader was John Smith (1895-1943). - # Answer: 1895 - # - # The original whitespace-tokenized answer will be "(1895-1943).". However - # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match - # the exact answer, 1895. - # - # However, this is not always possible. Consider the following: - # - # Question: What country is the top exporter of electornics? - # Context: The Japanese electronics industry is the lagest in the world. - # Answer: Japan - # - # In this case, the annotator chose "Japan" as a character sub-span of - # the word "Japanese". Since our WordPiece tokenizer does not split - # "Japanese", we just use "Japanese" as the annotation. This is fairly rare - # in SQuAD, but does happen. - tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) - - for new_start in range(input_start, input_end + 1): - for new_end in range(input_end, new_start - 1, -1): - text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) - if text_span == tok_answer_text: - return (new_start, new_end) - - return (input_start, input_end) + tok_start_position = orig_to_tok_index[example.start_position] + if example.end_position < len(example.doc_tokens) - 1: + tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 + else: + tok_end_position = len(all_doc_tokens) - 1 + tok_start_position, tok_end_position = _improve_answer_span( + all_doc_tokens, + tok_start_position, + tok_end_position, + tokenizer, + example.orig_answer_text, + ) + + # The -3 accounts for [CLS], [SEP] and [SEP] + max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 + + # We can have documents that are longer than the maximum sequence length. + # To deal with this we do a sliding window approach, where we take chunks + # of the up to our max length with a stride of `doc_stride`. + _DocSpan = collections.namedtuple( # pylint: disable=invalid-name + "DocSpan", ["start", "length"] + ) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, doc_stride) + + for doc_span_index, doc_span in enumerate(doc_spans): + tokens = [] + token_to_orig_map = {} + token_is_max_context = {} + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for token in query_tokens: + tokens.append(token) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) + + for i in range(doc_span.length): + split_token_index = doc_span.start + i + token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] + + is_max_context = _check_is_max_context( + doc_spans, doc_span_index, split_token_index + ) + token_is_max_context[len(tokens)] = is_max_context + tokens.append(all_doc_tokens[split_token_index]) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + + # Zero-pad up to the sequence length. + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + start_position = None + end_position = None + if is_training and not example.is_impossible: + # For training, if our document chunk does not contain an annotation + # we throw it out, since there is nothing to predict. + doc_start = doc_span.start + doc_end = doc_span.start + doc_span.length - 1 + out_of_span = False + if not ( + tok_start_position >= doc_start and tok_end_position <= doc_end + ): + out_of_span = True + if out_of_span: + start_position = 0 + end_position = 0 + else: + doc_offset = len(query_tokens) + 2 + start_position = tok_start_position - doc_start + doc_offset + end_position = tok_end_position - doc_start + doc_offset + + if is_training and example.is_impossible: + start_position = 0 + end_position = 0 + + if example_index < 20: + tf.compat.v1.logging.info("*** Example ***") + tf.compat.v1.logging.info("unique_id: %s" % (unique_id)) + tf.compat.v1.logging.info("example_index: %s" % (example_index)) + tf.compat.v1.logging.info("doc_span_index: %s" % (doc_span_index)) + tf.compat.v1.logging.info( + "tokens: %s" + % " ".join([tokenization.printable_text(x) for x in tokens]) + ) + tf.compat.v1.logging.info( + "token_to_orig_map: %s" + % " ".join( + [ + "%d:%d" % (x, y) + for (x, y) in six.iteritems(token_to_orig_map) + ] + ) + ) + tf.compat.v1.logging.info( + "token_is_max_context: %s" + % " ".join( + [ + "%d:%s" % (x, y) + for (x, y) in six.iteritems(token_is_max_context) + ] + ) + ) + tf.compat.v1.logging.info( + "input_ids: %s" % " ".join([str(x) for x in input_ids]) + ) + tf.compat.v1.logging.info( + "input_mask: %s" % " ".join([str(x) for x in input_mask]) + ) + tf.compat.v1.logging.info( + "segment_ids: %s" % " ".join([str(x) for x in segment_ids]) + ) + if is_training and example.is_impossible: + tf.compat.v1.logging.info("impossible example") + if is_training and not example.is_impossible: + answer_text = " ".join(tokens[start_position : (end_position + 1)]) + tf.compat.v1.logging.info("start_position: %d" % (start_position)) + tf.compat.v1.logging.info("end_position: %d" % (end_position)) + tf.compat.v1.logging.info( + "answer: %s" % (tokenization.printable_text(answer_text)) + ) + + feature = InputFeatures( + unique_id=unique_id, + example_index=example_index, + doc_span_index=doc_span_index, + tokens=tokens, + token_to_orig_map=token_to_orig_map, + token_is_max_context=token_is_max_context, + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + start_position=start_position, + end_position=end_position, + is_impossible=example.is_impossible, + ) + + # Run callback + output_fn(feature) + + unique_id += 1 + + +def _improve_answer_span( + doc_tokens, input_start, input_end, tokenizer, orig_answer_text +): + """Returns tokenized answer spans that better match the annotated answer.""" + + # The SQuAD annotations are character based. We first project them to + # whitespace-tokenized words. But then after WordPiece tokenization, we can + # often find a "better match". For example: + # + # Question: What year was John Smith born? + # Context: The leader was John Smith (1895-1943). + # Answer: 1895 + # + # The original whitespace-tokenized answer will be "(1895-1943).". However + # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match + # the exact answer, 1895. + # + # However, this is not always possible. Consider the following: + # + # Question: What country is the top exporter of electornics? + # Context: The Japanese electronics industry is the lagest in the world. + # Answer: Japan + # + # In this case, the annotator chose "Japan" as a character sub-span of + # the word "Japanese". Since our WordPiece tokenizer does not split + # "Japanese", we just use "Japanese" as the annotation. This is fairly rare + # in SQuAD, but does happen. + tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) + + for new_start in range(input_start, input_end + 1): + for new_end in range(input_end, new_start - 1, -1): + text_span = " ".join(doc_tokens[new_start : (new_end + 1)]) + if text_span == tok_answer_text: + return (new_start, new_end) + + return (input_start, input_end) def _check_is_max_context(doc_spans, cur_span_index, position): - """Check if this is the 'max context' doc span for the token.""" - - # Because of the sliding window approach taken to scoring documents, a single - # token can appear in multiple documents. E.g. - # Doc: the man went to the store and bought a gallon of milk - # Span A: the man went to the - # Span B: to the store and bought - # Span C: and bought a gallon of - # ... - # - # Now the word 'bought' will have two scores from spans B and C. We only - # want to consider the score with "maximum context", which we define as - # the *minimum* of its left and right context (the *sum* of left and - # right context will always be the same, of course). - # - # In the example the maximum context for 'bought' would be span C since - # it has 1 left context and 3 right context, while span B has 4 left context - # and 0 right context. - best_score = None - best_span_index = None - for (span_index, doc_span) in enumerate(doc_spans): - end = doc_span.start + doc_span.length - 1 - if position < doc_span.start: - continue - if position > end: - continue - num_left_context = position - doc_span.start - num_right_context = end - position - score = min(num_left_context, num_right_context) + 0.01 * doc_span.length - if best_score is None or score > best_score: - best_score = score - best_span_index = span_index - - return cur_span_index == best_span_index - - -def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, - use_one_hot_embeddings): - """Creates a classification model.""" - model = modeling.BertModel( - config=bert_config, - is_training=is_training, - input_ids=input_ids, - input_mask=input_mask, - token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) - - final_hidden = model.get_sequence_output() - - final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) - batch_size = final_hidden_shape[0] - seq_length = final_hidden_shape[1] - hidden_size = final_hidden_shape[2] - - output_weights = tf.compat.v1.get_variable( - "cls/squad/output_weights", [2, hidden_size], - initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)) - - output_bias = tf.compat.v1.get_variable( - "cls/squad/output_bias", [2], initializer=tf.compat.v1.zeros_initializer()) - - final_hidden_matrix = tf.reshape(final_hidden, - [batch_size * seq_length, hidden_size]) - logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - - logits = tf.reshape(logits, [batch_size, seq_length, 2]) - logits = tf.transpose(a=logits, perm=[2, 0, 1]) - - unstacked_logits = tf.unstack(logits, axis=0) - - (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1]) - - return (start_logits, end_logits) - - -def model_fn_builder(bert_config, init_checkpoint, learning_rate, - num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings): - """Returns `model_fn` closure for TPUEstimator.""" - - def model_fn(features, labels, mode, params): # pylint: disable=unused-argument - """The `model_fn` for TPUEstimator.""" - - tf.compat.v1.logging.info("*** Features ***") - for name in sorted(features.keys()): - tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) - - unique_ids = features["unique_ids"] - input_ids = features["input_ids"] - input_mask = features["input_mask"] - segment_ids = features["segment_ids"] - - is_training = (mode == tf.estimator.ModeKeys.TRAIN) - - (start_logits, end_logits) = create_model( - bert_config=bert_config, + """Check if this is the 'max context' doc span for the token.""" + + # Because of the sliding window approach taken to scoring documents, a single + # token can appear in multiple documents. E.g. + # Doc: the man went to the store and bought a gallon of milk + # Span A: the man went to the + # Span B: to the store and bought + # Span C: and bought a gallon of + # ... + # + # Now the word 'bought' will have two scores from spans B and C. We only + # want to consider the score with "maximum context", which we define as + # the *minimum* of its left and right context (the *sum* of left and + # right context will always be the same, of course). + # + # In the example the maximum context for 'bought' would be span C since + # it has 1 left context and 3 right context, while span B has 4 left context + # and 0 right context. + best_score = None + best_span_index = None + for span_index, doc_span in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + continue + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + + +def create_model( + bert_config, is_training, input_ids, input_mask, segment_ids, use_one_hot_embeddings +): + """Creates a classification model.""" + model = modeling.BertModel( + config=bert_config, is_training=is_training, input_ids=input_ids, input_mask=input_mask, - segment_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) - - tvars = tf.compat.v1.trainable_variables() - - initialized_variable_names = {} - scaffold_fn = None - if init_checkpoint: - (assignment_map, initialized_variable_names - ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) - if use_tpu: - - def tpu_scaffold(): - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - return tf.compat.v1.train.Scaffold() - - scaffold_fn = tpu_scaffold - else: - tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) - - tf.compat.v1.logging.info("**** Trainable Variables ****") - for var in tvars: - init_string = "" - if var.name in initialized_variable_names: - init_string = ", *INIT_FROM_CKPT*" - tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, - init_string) - - output_spec = None - if mode == tf.estimator.ModeKeys.TRAIN: - seq_length = modeling.get_shape_list(input_ids)[1] - - def compute_loss(logits, positions): - one_hot_positions = tf.one_hot( - positions, depth=seq_length, dtype=tf.float32) - log_probs = tf.nn.log_softmax(logits, axis=-1) - loss = -tf.reduce_mean( - input_tensor=tf.reduce_sum(input_tensor=one_hot_positions * log_probs, axis=-1)) - return loss - - start_positions = features["start_positions"] - end_positions = features["end_positions"] - - start_loss = compute_loss(start_logits, start_positions) - end_loss = compute_loss(end_logits, end_positions) - - total_loss = (start_loss + end_loss) / 2.0 - - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) - - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - train_op=train_op, - scaffold_fn=scaffold_fn) - elif mode == tf.estimator.ModeKeys.PREDICT: - predictions = { - "unique_ids": unique_ids, - "start_logits": start_logits, - "end_logits": end_logits, - } - output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( - mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) - else: - raise ValueError( - "Only TRAIN and PREDICT modes are supported: %s" % (mode)) - - return output_spec - - return model_fn + token_type_ids=segment_ids, + use_one_hot_embeddings=use_one_hot_embeddings, + ) + + final_hidden = model.get_sequence_output() + + final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) + batch_size = final_hidden_shape[0] + seq_length = final_hidden_shape[1] + hidden_size = final_hidden_shape[2] + + output_weights = tf.compat.v1.get_variable( + "cls/squad/output_weights", + [2, hidden_size], + initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02), + ) + + output_bias = tf.compat.v1.get_variable( + "cls/squad/output_bias", [2], initializer=tf.compat.v1.zeros_initializer() + ) + + final_hidden_matrix = tf.reshape( + final_hidden, [batch_size * seq_length, hidden_size] + ) + logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) + logits = tf.nn.bias_add(logits, output_bias) + + logits = tf.reshape(logits, [batch_size, seq_length, 2]) + logits = tf.transpose(a=logits, perm=[2, 0, 1]) + + unstacked_logits = tf.unstack(logits, axis=0) + + start_logits, end_logits = (unstacked_logits[0], unstacked_logits[1]) + + return (start_logits, end_logits) + + +def model_fn_builder( + bert_config, + init_checkpoint, + learning_rate, + num_train_steps, + num_warmup_steps, + use_tpu, + use_one_hot_embeddings, +): + """Returns `model_fn` closure for TPUEstimator.""" + + def model_fn(features, labels, mode, params): # pylint: disable=unused-argument + """The `model_fn` for TPUEstimator.""" + + tf.compat.v1.logging.info("*** Features ***") + for name in sorted(features.keys()): + tf.compat.v1.logging.info( + " name = %s, shape = %s" % (name, features[name].shape) + ) + + unique_ids = features["unique_ids"] + input_ids = features["input_ids"] + input_mask = features["input_mask"] + segment_ids = features["segment_ids"] + + is_training = mode == tf.estimator.ModeKeys.TRAIN + + start_logits, end_logits = create_model( + bert_config=bert_config, + is_training=is_training, + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + use_one_hot_embeddings=use_one_hot_embeddings, + ) + + tvars = tf.compat.v1.trainable_variables() + + initialized_variable_names = {} + scaffold_fn = None + if init_checkpoint: + assignment_map, initialized_variable_names = ( + modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) + ) + if use_tpu: + + def tpu_scaffold(): + tf.compat.v1.train.init_from_checkpoint( + init_checkpoint, assignment_map + ) + return tf.compat.v1.train.Scaffold() + + scaffold_fn = tpu_scaffold + else: + tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) + + tf.compat.v1.logging.info("**** Trainable Variables ****") + for var in tvars: + init_string = "" + if var.name in initialized_variable_names: + init_string = ", *INIT_FROM_CKPT*" + tf.compat.v1.logging.info( + " name = %s, shape = %s%s", var.name, var.shape, init_string + ) + + output_spec = None + if mode == tf.estimator.ModeKeys.TRAIN: + seq_length = modeling.get_shape_list(input_ids)[1] + + def compute_loss(logits, positions): + one_hot_positions = tf.one_hot( + positions, depth=seq_length, dtype=tf.float32 + ) + log_probs = tf.nn.log_softmax(logits, axis=-1) + loss = -tf.reduce_mean( + input_tensor=tf.reduce_sum( + input_tensor=one_hot_positions * log_probs, axis=-1 + ) + ) + return loss + + start_positions = features["start_positions"] + end_positions = features["end_positions"] + + start_loss = compute_loss(start_logits, start_positions) + end_loss = compute_loss(end_logits, end_positions) + + total_loss = (start_loss + end_loss) / 2.0 + + train_op = optimization.create_optimizer( + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu + ) + + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn + ) + elif mode == tf.estimator.ModeKeys.PREDICT: + predictions = { + "unique_ids": unique_ids, + "start_logits": start_logits, + "end_logits": end_logits, + } + output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( + mode=mode, predictions=predictions, scaffold_fn=scaffold_fn + ) + else: + raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode)) + + return output_spec + + return model_fn def input_fn_builder(input_file, seq_length, is_training, drop_remainder): - """Creates an `input_fn` closure to be passed to TPUEstimator.""" - - name_to_features = { - "unique_ids": tf.io.FixedLenFeature([], tf.int64), - "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), - "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64), - "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64), - } - - if is_training: - name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64) - name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64) - - def _decode_record(record, name_to_features): - """Decodes a record to a TensorFlow example.""" - example = tf.io.parse_single_example(serialized=record, features=name_to_features) - - # tf.Example only supports tf.int64, but the TPU only supports tf.int32. - # So cast all int64 to int32. - for name in list(example.keys()): - t = example[name] - if t.dtype == tf.int64: - t = tf.cast(t, dtype=tf.int32) - example[name] = t - - return example - - def input_fn(params): - """The actual input function.""" - batch_size = params["batch_size"] - - # For training, we want a lot of parallel reading and shuffling. - # For eval, we want no shuffling and parallel reading doesn't matter. - d = tf.data.TFRecordDataset(input_file) + """Creates an `input_fn` closure to be passed to TPUEstimator.""" + + name_to_features = { + "unique_ids": tf.io.FixedLenFeature([], tf.int64), + "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), + "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64), + "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64), + } + if is_training: - d = d.repeat() - d = d.shuffle(buffer_size=100) - - d = d.apply( - tf.data.experimental.map_and_batch( - lambda record: _decode_record(record, name_to_features), - batch_size=batch_size, - drop_remainder=drop_remainder)) - - return d - - return input_fn - - -RawResult = collections.namedtuple("RawResult", - ["unique_id", "start_logits", "end_logits"]) - - -def write_predictions(all_examples, all_features, all_results, n_best_size, - max_answer_length, do_lower_case, output_prediction_file, - output_nbest_file, output_null_log_odds_file): - """Write final predictions to the json file and log-odds of null if needed.""" - tf.compat.v1.logging.info("Writing predictions to: %s" % (output_prediction_file)) - tf.compat.v1.logging.info("Writing nbest to: %s" % (output_nbest_file)) - - example_index_to_features = collections.defaultdict(list) - for feature in all_features: - example_index_to_features[feature.example_index].append(feature) - - unique_id_to_result = {} - for result in all_results: - unique_id_to_result[result.unique_id] = result - - _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name - "PrelimPrediction", - ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) - - all_predictions = collections.OrderedDict() - all_nbest_json = collections.OrderedDict() - scores_diff_json = collections.OrderedDict() - - for (example_index, example) in enumerate(all_examples): - features = example_index_to_features[example_index] - - prelim_predictions = [] - # keep track of the minimum score of null start+end of position 0 - score_null = 1000000 # large and positive - min_null_feature_index = 0 # the paragraph slice with min mull score - null_start_logit = 0 # the start logit at the slice with min null score - null_end_logit = 0 # the end logit at the slice with min null score - for (feature_index, feature) in enumerate(features): - result = unique_id_to_result[feature.unique_id] - start_indexes = _get_best_indexes(result.start_logits, n_best_size) - end_indexes = _get_best_indexes(result.end_logits, n_best_size) - # if we could have irrelevant answers, get the min score of irrelevant - #if FLAGS.version_2_with_negative: - # feature_null_score = result.start_logits[0] + result.end_logits[0] - # if feature_null_score < score_null: - # score_null = feature_null_score - # min_null_feature_index = feature_index - # null_start_logit = result.start_logits[0] - # null_end_logit = result.end_logits[0] - for start_index in start_indexes: - for end_index in end_indexes: - # We could hypothetically create invalid predictions, e.g., predict - # that the start of the span is in the question. We throw out all - # invalid predictions. - if start_index >= len(feature.tokens): - continue - if end_index >= len(feature.tokens): - continue - if start_index not in feature.token_to_orig_map: - continue - if end_index not in feature.token_to_orig_map: - continue - if not feature.token_is_max_context.get(start_index, False): - continue - if end_index < start_index: - continue - length = end_index - start_index + 1 - if length > max_answer_length: - continue - prelim_predictions.append( - _PrelimPrediction( - feature_index=feature_index, - start_index=start_index, - end_index=end_index, - start_logit=result.start_logits[start_index], - end_logit=result.end_logits[end_index])) - - #if FLAGS.version_2_with_negative: - # prelim_predictions.append( - # _PrelimPrediction( - # feature_index=min_null_feature_index, - # start_index=0, - # end_index=0, - # start_logit=null_start_logit, - # end_logit=null_end_logit)) - prelim_predictions = sorted( - prelim_predictions, - key=lambda x: (x.start_logit + x.end_logit), - reverse=True) - - _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name - "NbestPrediction", ["text", "start_logit", "end_logit"]) - - seen_predictions = {} - nbest = [] - for pred in prelim_predictions: - if len(nbest) >= n_best_size: - break - feature = features[pred.feature_index] - if pred.start_index > 0: # this is a non-null prediction - tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] - orig_doc_start = feature.token_to_orig_map[pred.start_index] - orig_doc_end = feature.token_to_orig_map[pred.end_index] - orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] - tok_text = " ".join(tok_tokens) - - # De-tokenize WordPieces that have been split off. - tok_text = tok_text.replace(" ##", "") - tok_text = tok_text.replace("##", "") - - # Clean whitespace - tok_text = tok_text.strip() - tok_text = " ".join(tok_text.split()) - orig_text = " ".join(orig_tokens) - - final_text = get_final_text(tok_text, orig_text, do_lower_case) - if final_text in seen_predictions: - continue - - seen_predictions[final_text] = True - else: - final_text = "" - seen_predictions[final_text] = True - - nbest.append( - _NbestPrediction( - text=final_text, - start_logit=pred.start_logit, - end_logit=pred.end_logit)) - - # if we didn't inlude the empty option in the n-best, inlcude it - #if FLAGS.version_2_with_negative: - # if "" not in seen_predictions: - # nbest.append( - # _NbestPrediction( - # text="", start_logit=null_start_logit, - # end_logit=null_end_logit)) - # In very rare edge cases we could have no valid predictions. So we - # just create a nonce prediction in this case to avoid failure. - if not nbest: - nbest.append( - _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) - - assert len(nbest) >= 1 - - total_scores = [] - best_non_null_entry = None - for entry in nbest: - total_scores.append(entry.start_logit + entry.end_logit) - if not best_non_null_entry: - if entry.text: - best_non_null_entry = entry - - probs = _compute_softmax(total_scores) - - nbest_json = [] - for (i, entry) in enumerate(nbest): - output = collections.OrderedDict() - output["text"] = entry.text - output["probability"] = probs[i] - output["start_logit"] = entry.start_logit - output["end_logit"] = entry.end_logit - nbest_json.append(output) - - assert len(nbest_json) >= 1 - - #if not FLAGS.version_2_with_negative: - all_predictions[example.qas_id] = nbest_json[0]["text"] - #else: - # predict "" iff the null score - the score of best non-null > threshold - #score_diff = score_null - best_non_null_entry.start_logit - ( - # best_non_null_entry.end_logit) - #scores_diff_json[example.qas_id] = score_diff - #if score_diff > FLAGS.null_score_diff_threshold: - # all_predictions[example.qas_id] = "" - #else: - # all_predictions[example.qas_id] = best_non_null_entry.text - - all_nbest_json[example.qas_id] = nbest_json - - with tf.io.gfile.GFile(output_prediction_file, "w") as writer: - writer.write(json.dumps(all_predictions, indent=4) + "\n") - - with tf.io.gfile.GFile(output_nbest_file, "w") as writer: - writer.write(json.dumps(all_nbest_json, indent=4) + "\n") - - #if FLAGS.version_2_with_negative: - # with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer: - # writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64) + name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64) + + def _decode_record(record, name_to_features): + """Decodes a record to a TensorFlow example.""" + example = tf.io.parse_single_example( + serialized=record, features=name_to_features + ) + + # tf.Example only supports tf.int64, but the TPU only supports tf.int32. + # So cast all int64 to int32. + for name in list(example.keys()): + t = example[name] + if t.dtype == tf.int64: + t = tf.cast(t, dtype=tf.int32) + example[name] = t + + return example + + def input_fn(params): + """The actual input function.""" + batch_size = params["batch_size"] + + # For training, we want a lot of parallel reading and shuffling. + # For eval, we want no shuffling and parallel reading doesn't matter. + d = tf.data.TFRecordDataset(input_file) + if is_training: + d = d.repeat() + d = d.shuffle(buffer_size=100) + + d = d.apply( + tf.data.experimental.map_and_batch( + lambda record: _decode_record(record, name_to_features), + batch_size=batch_size, + drop_remainder=drop_remainder, + ) + ) + + return d + + return input_fn + + +RawResult = collections.namedtuple( + "RawResult", ["unique_id", "start_logits", "end_logits"] +) + + +def write_predictions( + all_examples, + all_features, + all_results, + n_best_size, + max_answer_length, + do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, +): + """Write final predictions to the json file and log-odds of null if needed.""" + tf.compat.v1.logging.info("Writing predictions to: %s" % (output_prediction_file)) + tf.compat.v1.logging.info("Writing nbest to: %s" % (output_nbest_file)) + + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", + ["feature_index", "start_index", "end_index", "start_logit", "end_logit"], + ) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() + + for example_index, example in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + score_null = 1000000 # large and positive + min_null_feature_index = 0 # the paragraph slice with min mull score + null_start_logit = 0 # the start logit at the slice with min null score + null_end_logit = 0 # the end logit at the slice with min null score + for feature_index, feature in enumerate(features): + result = unique_id_to_result[feature.unique_id] + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + # if we could have irrelevant answers, get the min score of irrelevant + # if FLAGS.version_2_with_negative: + # feature_null_score = result.start_logits[0] + result.end_logits[0] + # if feature_null_score < score_null: + # score_null = feature_null_score + # min_null_feature_index = feature_index + # null_start_logit = result.start_logits[0] + # null_end_logit = result.end_logits[0] + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= len(feature.tokens): + continue + if end_index >= len(feature.tokens): + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index], + ) + ) + + # if FLAGS.version_2_with_negative: + # prelim_predictions.append( + # _PrelimPrediction( + # feature_index=min_null_feature_index, + # start_index=0, + # end_index=0, + # start_logit=null_start_logit, + # end_logit=null_end_logit)) + prelim_predictions = sorted( + prelim_predictions, + key=lambda x: (x.start_logit + x.end_logit), + reverse=True, + ) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"] + ) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + if pred.start_index > 0: # this is a non-null prediction + tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)] + tok_text = " ".join(tok_tokens) + + # De-tokenize WordPieces that have been split off. + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text = get_final_text(tok_text, orig_text, do_lower_case) + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + else: + final_text = "" + seen_predictions[final_text] = True + + nbest.append( + _NbestPrediction( + text=final_text, + start_logit=pred.start_logit, + end_logit=pred.end_logit, + ) + ) + + # if we didn't inlude the empty option in the n-best, inlcude it + # if FLAGS.version_2_with_negative: + # if "" not in seen_predictions: + # nbest.append( + # _NbestPrediction( + # text="", start_logit=null_start_logit, + # end_logit=null_end_logit)) + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + assert len(nbest) >= 1 + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + if not best_non_null_entry: + if entry.text: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for i, entry in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + nbest_json.append(output) + + assert len(nbest_json) >= 1 + + # if not FLAGS.version_2_with_negative: + all_predictions[example.qas_id] = nbest_json[0]["text"] + # else: + # predict "" iff the null score - the score of best non-null > threshold + # score_diff = score_null - best_non_null_entry.start_logit - ( + # best_non_null_entry.end_logit) + # scores_diff_json[example.qas_id] = score_diff + # if score_diff > FLAGS.null_score_diff_threshold: + # all_predictions[example.qas_id] = "" + # else: + # all_predictions[example.qas_id] = best_non_null_entry.text + + all_nbest_json[example.qas_id] = nbest_json + + with tf.io.gfile.GFile(output_prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + + with tf.io.gfile.GFile(output_nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + + # if FLAGS.version_2_with_negative: + # with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer: + # writer.write(json.dumps(scores_diff_json, indent=4) + "\n") def get_final_text(pred_text, orig_text, do_lower_case): - """Project the tokenized prediction back to the original text.""" - - # When we created the data, we kept track of the alignment between original - # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So - # now `orig_text` contains the span of our original text corresponding to the - # span that we predicted. - # - # However, `orig_text` may contain extra characters that we don't want in - # our prediction. - # - # For example, let's say: - # pred_text = steve smith - # orig_text = Steve Smith's - # - # We don't want to return `orig_text` because it contains the extra "'s". - # - # We don't want to return `pred_text` because it's already been normalized - # (the SQuAD eval script also does punctuation stripping/lower casing but - # our tokenizer does additional normalization like stripping accent - # characters). - # - # What we really want to return is "Steve Smith". - # - # Therefore, we have to apply a semi-complicated alignment heruistic between - # `pred_text` and `orig_text` to get a character-to-charcter alignment. This - # can fail in certain cases in which case we just return `orig_text`. - - def _strip_spaces(text): - ns_chars = [] - ns_to_s_map = collections.OrderedDict() - for (i, c) in enumerate(text): - if c == " ": - continue - ns_to_s_map[len(ns_chars)] = i - ns_chars.append(c) - ns_text = "".join(ns_chars) - return (ns_text, ns_to_s_map) - - # We first tokenize `orig_text`, strip whitespace from the result - # and `pred_text`, and check if they are the same length. If they are - # NOT the same length, the heuristic has failed. If they are the same - # length, we assume the characters are one-to-one aligned. - tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) - - tok_text = " ".join(tokenizer.tokenize(orig_text)) - - start_position = tok_text.find(pred_text) - if start_position == -1: - if FLAGS.verbose_logging: - tf.compat.v1.logging.info( - "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) - return orig_text - end_position = start_position + len(pred_text) - 1 - - (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) - (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) - - if len(orig_ns_text) != len(tok_ns_text): - if FLAGS.verbose_logging: - tf.compat.v1.logging.info("Length not equal after stripping spaces: '%s' vs '%s'", - orig_ns_text, tok_ns_text) - return orig_text - - # We then project the characters in `pred_text` back to `orig_text` using - # the character-to-character alignment. - tok_s_to_ns_map = {} - for (i, tok_index) in six.iteritems(tok_ns_to_s_map): - tok_s_to_ns_map[tok_index] = i - - orig_start_position = None - if start_position in tok_s_to_ns_map: - ns_start_position = tok_s_to_ns_map[start_position] - if ns_start_position in orig_ns_to_s_map: - orig_start_position = orig_ns_to_s_map[ns_start_position] - - if orig_start_position is None: - if FLAGS.verbose_logging: - tf.compat.v1.logging.info("Couldn't map start position") - return orig_text - - orig_end_position = None - if end_position in tok_s_to_ns_map: - ns_end_position = tok_s_to_ns_map[end_position] - if ns_end_position in orig_ns_to_s_map: - orig_end_position = orig_ns_to_s_map[ns_end_position] - - if orig_end_position is None: - if FLAGS.verbose_logging: - tf.compat.v1.logging.info("Couldn't map end position") - return orig_text - - output_text = orig_text[orig_start_position:(orig_end_position + 1)] - return output_text + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heruistic between + # `pred_text` and `orig_text` to get a character-to-charcter alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for i, c in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if FLAGS.verbose_logging: + tf.compat.v1.logging.info( + "Unable to find text: '%s' in '%s'" % (pred_text, orig_text) + ) + return orig_text + end_position = start_position + len(pred_text) - 1 + + orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text) + tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if FLAGS.verbose_logging: + tf.compat.v1.logging.info( + "Length not equal after stripping spaces: '%s' vs '%s'", + orig_ns_text, + tok_ns_text, + ) + return orig_text + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. + tok_s_to_ns_map = {} + for i, tok_index in six.iteritems(tok_ns_to_s_map): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if FLAGS.verbose_logging: + tf.compat.v1.logging.info("Couldn't map start position") + return orig_text + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if FLAGS.verbose_logging: + tf.compat.v1.logging.info("Couldn't map end position") + return orig_text + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text def _get_best_indexes(logits, n_best_size): - """Get the n-best logits from a list.""" - index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) + """Get the n-best logits from a list.""" + index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) - best_indexes = [] - for i in range(len(index_and_score)): - if i >= n_best_size: - break - best_indexes.append(index_and_score[i][0]) - return best_indexes + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes def _compute_softmax(scores): - """Compute softmax probability over raw logits.""" - if not scores: - return [] + """Compute softmax probability over raw logits.""" + if not scores: + return [] - max_score = None - for score in scores: - if max_score is None or score > max_score: - max_score = score + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score - exp_scores = [] - total_sum = 0.0 - for score in scores: - x = math.exp(score - max_score) - exp_scores.append(x) - total_sum += x + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x - probs = [] - for score in exp_scores: - probs.append(score / total_sum) - return probs + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs class FeatureWriter(object): - """Writes InputFeature to TF example file.""" + """Writes InputFeature to TF example file.""" - def __init__(self, filename, is_training): - self.filename = filename - self.is_training = is_training - self.num_features = 0 - self._writer = tf.io.TFRecordWriter(filename) + def __init__(self, filename, is_training): + self.filename = filename + self.is_training = is_training + self.num_features = 0 + self._writer = tf.io.TFRecordWriter(filename) - def process_feature(self, feature): - """Write a InputFeature to the TFRecordWriter as a tf.train.Example.""" - self.num_features += 1 + def process_feature(self, feature): + """Write a InputFeature to the TFRecordWriter as a tf.train.Example.""" + self.num_features += 1 - def create_int_feature(values): - feature = tf.train.Feature( - int64_list=tf.train.Int64List(value=list(values))) - return feature + def create_int_feature(values): + feature = tf.train.Feature( + int64_list=tf.train.Int64List(value=list(values)) + ) + return feature - features = collections.OrderedDict() - features["unique_ids"] = create_int_feature([feature.unique_id]) - features["input_ids"] = create_int_feature(feature.input_ids) - features["input_mask"] = create_int_feature(feature.input_mask) - features["segment_ids"] = create_int_feature(feature.segment_ids) + features = collections.OrderedDict() + features["unique_ids"] = create_int_feature([feature.unique_id]) + features["input_ids"] = create_int_feature(feature.input_ids) + features["input_mask"] = create_int_feature(feature.input_mask) + features["segment_ids"] = create_int_feature(feature.segment_ids) - if self.is_training: - features["start_positions"] = create_int_feature([feature.start_position]) - features["end_positions"] = create_int_feature([feature.end_position]) - impossible = 0 - if feature.is_impossible: - impossible = 1 - features["is_impossible"] = create_int_feature([impossible]) + if self.is_training: + features["start_positions"] = create_int_feature([feature.start_position]) + features["end_positions"] = create_int_feature([feature.end_position]) + impossible = 0 + if feature.is_impossible: + impossible = 1 + features["is_impossible"] = create_int_feature([impossible]) - tf_example = tf.train.Example(features=tf.train.Features(feature=features)) - self._writer.write(tf_example.SerializeToString()) + tf_example = tf.train.Example(features=tf.train.Features(feature=features)) + self._writer.write(tf_example.SerializeToString()) - def close(self): - self._writer.close() + def close(self): + self._writer.close() def validate_flags_or_throw(bert_config): - """Validate the input FLAGS or throw an exception.""" - tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, - FLAGS.init_checkpoint) - - if not FLAGS.do_train and not FLAGS.do_predict: - raise ValueError("At least one of `do_train` or `do_predict` must be True.") - - if FLAGS.do_train: - if not FLAGS.train_file: - raise ValueError( - "If `do_train` is True, then `train_file` must be specified.") - if FLAGS.do_predict: - if not FLAGS.predict_file: - raise ValueError( - "If `do_predict` is True, then `predict_file` must be specified.") - - if FLAGS.max_seq_length > bert_config.max_position_embeddings: - raise ValueError( - "Cannot use sequence length %d because the BERT model " - "was only trained up to sequence length %d" % - (FLAGS.max_seq_length, bert_config.max_position_embeddings)) - - if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: - raise ValueError( - "The max_seq_length (%d) must be greater than max_query_length " - "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) + """Validate the input FLAGS or throw an exception.""" + tokenization.validate_case_matches_checkpoint( + FLAGS.do_lower_case, FLAGS.init_checkpoint + ) + + if not FLAGS.do_train and not FLAGS.do_predict: + raise ValueError("At least one of `do_train` or `do_predict` must be True.") + + if FLAGS.do_train: + if not FLAGS.train_file: + raise ValueError( + "If `do_train` is True, then `train_file` must be specified." + ) + if FLAGS.do_predict: + if not FLAGS.predict_file: + raise ValueError( + "If `do_predict` is True, then `predict_file` must be specified." + ) + + if FLAGS.max_seq_length > bert_config.max_position_embeddings: + raise ValueError( + "Cannot use sequence length %d because the BERT model " + "was only trained up to sequence length %d" + % (FLAGS.max_seq_length, bert_config.max_position_embeddings) + ) + + if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: + raise ValueError( + "The max_seq_length (%d) must be greater than max_query_length " + "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length) + ) def main(_): - tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) - - bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) - - validate_flags_or_throw(bert_config) - - tf.io.gfile.makedirs(FLAGS.output_dir) - - tokenizer = tokenization.FullTokenizer( - vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) - - tpu_cluster_resolver = None - if FLAGS.use_tpu and FLAGS.tpu_name: - tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( - FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.compat.v1.estimator.tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=FLAGS.master, - model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, - tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - - train_examples = None - num_train_steps = None - num_warmup_steps = None - if FLAGS.do_train: - train_examples = read_squad_examples( - input_file=FLAGS.train_file, is_training=True) - num_train_steps = int( - len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) - num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) - - # Pre-shuffle the input to avoid having to make a very large shuffle - # buffer in in the `input_fn`. - rng = random.Random(12345) - rng.shuffle(train_examples) - - model_fn = model_fn_builder( - bert_config=bert_config, - init_checkpoint=FLAGS.init_checkpoint, - learning_rate=FLAGS.learning_rate, - num_train_steps=num_train_steps, - num_warmup_steps=num_warmup_steps, - use_tpu=FLAGS.use_tpu, - use_one_hot_embeddings=FLAGS.use_tpu) - - # If TPU is not available, this will fall back to normal Estimator on CPU - # or GPU. - estimator = tf.compat.v1.estimator.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, - model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - predict_batch_size=FLAGS.predict_batch_size) - - if FLAGS.do_train: - # We write to a temporary file to avoid storing very large constant tensors - # in memory. - train_writer = FeatureWriter( - filename=os.path.join(FLAGS.output_dir, "train.tf_record"), - is_training=True) - convert_examples_to_features( - examples=train_examples, - tokenizer=tokenizer, - max_seq_length=FLAGS.max_seq_length, - doc_stride=FLAGS.doc_stride, - max_query_length=FLAGS.max_query_length, - is_training=True, - output_fn=train_writer.process_feature) - train_writer.close() - - tf.compat.v1.logging.info("***** Running training *****") - tf.compat.v1.logging.info(" Num orig examples = %d", len(train_examples)) - tf.compat.v1.logging.info(" Num split examples = %d", train_writer.num_features) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) - tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) - del train_examples - - train_input_fn = input_fn_builder( - input_file=train_writer.filename, - seq_length=FLAGS.max_seq_length, - is_training=True, - drop_remainder=True) - estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) - - if FLAGS.do_predict: - eval_examples = read_squad_examples( - input_file=FLAGS.predict_file, is_training=False) - - eval_writer = FeatureWriter( - filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), - is_training=False) - eval_features = [] - - def append_feature(feature): - eval_features.append(feature) - eval_writer.process_feature(feature) - - convert_examples_to_features( - examples=eval_examples, - tokenizer=tokenizer, - max_seq_length=FLAGS.max_seq_length, - doc_stride=FLAGS.doc_stride, - max_query_length=FLAGS.max_query_length, - is_training=False, - output_fn=append_feature) - eval_writer.close() - - tf.compat.v1.logging.info("***** Running predictions *****") - tf.compat.v1.logging.info(" Num orig examples = %d", len(eval_examples)) - tf.compat.v1.logging.info(" Num split examples = %d", len(eval_features)) - tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size) - - all_results = [] - - predict_input_fn = input_fn_builder( - input_file=eval_writer.filename, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=False) - - # If running eval on the TPU, you will need to specify the number of - # steps. - all_results = [] - for result in estimator.predict( - predict_input_fn, yield_single_examples=True): - if len(all_results) % 1000 == 0: - tf.compat.v1.logging.info("Processing example: %d" % (len(all_results))) - unique_id = int(result["unique_ids"]) - start_logits = [float(x) for x in result["start_logits"].flat] - end_logits = [float(x) for x in result["end_logits"].flat] - all_results.append( - RawResult( - unique_id=unique_id, - start_logits=start_logits, - end_logits=end_logits)) - - output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json") - output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json") - output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json") - - write_predictions(eval_examples, eval_features, all_results, - FLAGS.n_best_size, FLAGS.max_answer_length, - FLAGS.do_lower_case, output_prediction_file, - output_nbest_file, output_null_log_odds_file) + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) + + validate_flags_or_throw(bert_config) + + tf.io.gfile.makedirs(FLAGS.output_dir) + + tokenizer = tokenization.FullTokenizer( + vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case + ) + + tpu_cluster_resolver = None + if FLAGS.use_tpu and FLAGS.tpu_name: + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project + ) + + is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 + run_config = tf.compat.v1.estimator.tpu.RunConfig( + cluster=tpu_cluster_resolver, + master=FLAGS.master, + model_dir=FLAGS.output_dir, + save_checkpoints_steps=FLAGS.save_checkpoints_steps, + tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_tpu_cores, + per_host_input_for_training=is_per_host, + ), + ) + + train_examples = None + num_train_steps = None + num_warmup_steps = None + if FLAGS.do_train: + train_examples = read_squad_examples( + input_file=FLAGS.train_file, is_training=True + ) + num_train_steps = int( + len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs + ) + num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) + + # Pre-shuffle the input to avoid having to make a very large shuffle + # buffer in in the `input_fn`. + rng = random.Random(12345) + rng.shuffle(train_examples) + + model_fn = model_fn_builder( + bert_config=bert_config, + init_checkpoint=FLAGS.init_checkpoint, + learning_rate=FLAGS.learning_rate, + num_train_steps=num_train_steps, + num_warmup_steps=num_warmup_steps, + use_tpu=FLAGS.use_tpu, + use_one_hot_embeddings=FLAGS.use_tpu, + ) + + # If TPU is not available, this will fall back to normal Estimator on CPU + # or GPU. + estimator = tf.compat.v1.estimator.tpu.TPUEstimator( + use_tpu=FLAGS.use_tpu, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS.train_batch_size, + predict_batch_size=FLAGS.predict_batch_size, + ) + + if FLAGS.do_train: + # We write to a temporary file to avoid storing very large constant tensors + # in memory. + train_writer = FeatureWriter( + filename=os.path.join(FLAGS.output_dir, "train.tf_record"), is_training=True + ) + convert_examples_to_features( + examples=train_examples, + tokenizer=tokenizer, + max_seq_length=FLAGS.max_seq_length, + doc_stride=FLAGS.doc_stride, + max_query_length=FLAGS.max_query_length, + is_training=True, + output_fn=train_writer.process_feature, + ) + train_writer.close() + + tf.compat.v1.logging.info("***** Running training *****") + tf.compat.v1.logging.info(" Num orig examples = %d", len(train_examples)) + tf.compat.v1.logging.info( + " Num split examples = %d", train_writer.num_features + ) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) + tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) + del train_examples + + train_input_fn = input_fn_builder( + input_file=train_writer.filename, + seq_length=FLAGS.max_seq_length, + is_training=True, + drop_remainder=True, + ) + estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) + + if FLAGS.do_predict: + eval_examples = read_squad_examples( + input_file=FLAGS.predict_file, is_training=False + ) + + eval_writer = FeatureWriter( + filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), is_training=False + ) + eval_features = [] + + def append_feature(feature): + eval_features.append(feature) + eval_writer.process_feature(feature) + + convert_examples_to_features( + examples=eval_examples, + tokenizer=tokenizer, + max_seq_length=FLAGS.max_seq_length, + doc_stride=FLAGS.doc_stride, + max_query_length=FLAGS.max_query_length, + is_training=False, + output_fn=append_feature, + ) + eval_writer.close() + + tf.compat.v1.logging.info("***** Running predictions *****") + tf.compat.v1.logging.info(" Num orig examples = %d", len(eval_examples)) + tf.compat.v1.logging.info(" Num split examples = %d", len(eval_features)) + tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size) + + all_results = [] + + predict_input_fn = input_fn_builder( + input_file=eval_writer.filename, + seq_length=FLAGS.max_seq_length, + is_training=False, + drop_remainder=False, + ) + + # If running eval on the TPU, you will need to specify the number of + # steps. + all_results = [] + for result in estimator.predict(predict_input_fn, yield_single_examples=True): + if len(all_results) % 1000 == 0: + tf.compat.v1.logging.info("Processing example: %d" % (len(all_results))) + unique_id = int(result["unique_ids"]) + start_logits = [float(x) for x in result["start_logits"].flat] + end_logits = [float(x) for x in result["end_logits"].flat] + all_results.append( + RawResult( + unique_id=unique_id, + start_logits=start_logits, + end_logits=end_logits, + ) + ) + + output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json") + output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json") + output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json") + + write_predictions( + eval_examples, + eval_features, + all_results, + FLAGS.n_best_size, + FLAGS.max_answer_length, + FLAGS.do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + ) if __name__ == "__main__": - flags.mark_flag_as_required("vocab_file") - flags.mark_flag_as_required("bert_config_file") - flags.mark_flag_as_required("output_dir") - tf.compat.v1.app.run() - + flags.mark_flag_as_required("vocab_file") + flags.mark_flag_as_required("bert_config_file") + flags.mark_flag_as_required("output_dir") + tf.compat.v1.app.run() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization.py index 52c92adb..75d3c598 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization.py @@ -26,374 +26,384 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): - """Checks whether the casing config is consistent with the checkpoint name.""" - - # The casing has to be passed in by the user and there is no explicit check - # as to whether it matches the checkpoint. The casing information probably - # should have been stored in the bert_config.json file, but it's not, so - # we have to heuristically detect it to validate. - - if not init_checkpoint: - return - - m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) - if m is None: - return - - model_name = m.group(1) - - lower_models = [ - "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", - "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" - ] - - cased_models = [ - "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", - "multi_cased_L-12_H-768_A-12" - ] - - is_bad_config = False - if model_name in lower_models and not do_lower_case: - is_bad_config = True - actual_flag = "False" - case_name = "lowercased" - opposite_flag = "True" - - if model_name in cased_models and do_lower_case: - is_bad_config = True - actual_flag = "True" - case_name = "cased" - opposite_flag = "False" - - if is_bad_config: - raise ValueError( - "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " - "However, `%s` seems to be a %s model, so you " - "should pass in `--do_lower_case=%s` so that the fine-tuning matches " - "how the model was pre-training. If this error is wrong, please " - "just comment out this check." % (actual_flag, init_checkpoint, - model_name, case_name, opposite_flag)) + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", + "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", + "chinese_L-12_H-768_A-12", + ] + + cased_models = [ + "cased_L-12_H-768_A-12", + "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12", + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." + % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag) + ) def convert_to_unicode(text): - """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode("utf-8", "ignore") - elif isinstance(text, unicode): - return text - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") + raise ValueError("Not running on Python2 or Python 3?") def printable_text(text): - """Returns text encoded in a way suitable for print or `tf.logging`.""" - - # These functions want `str` for both Python2 and Python3, but in one case - # it's a Unicode string and in the other it's a byte string. - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text - elif isinstance(text, unicode): - return text.encode("utf-8") + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") + raise ValueError("Not running on Python2 or Python 3?") def load_vocab(vocab_file): - """Loads a vocabulary file into a dictionary.""" - vocab = collections.OrderedDict() - index = 0 - with tf.io.gfile.GFile(vocab_file, "r") as reader: - while True: - token = convert_to_unicode(reader.readline()) - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 - return vocab + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with tf.io.gfile.GFile(vocab_file, "r") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab def convert_by_vocab(vocab, items): - """Converts a sequence of [tokens|ids] using the vocab.""" - output = [] - for item in items: - output.append(vocab[item]) - return output + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output def convert_tokens_to_ids(vocab, tokens): - return convert_by_vocab(vocab, tokens) + return convert_by_vocab(vocab, tokens) def convert_ids_to_tokens(inv_vocab, ids): - return convert_by_vocab(inv_vocab, ids) + return convert_by_vocab(inv_vocab, ids) def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens class FullTokenizer(object): - """Runs end-to-end tokenziation.""" + """Runs end-to-end tokenziation.""" - def __init__(self, vocab_file, do_lower_case=True): - self.vocab = load_vocab(vocab_file) - self.inv_vocab = {v: k for k, v in self.vocab.items()} - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) - def tokenize(self, text): - split_tokens = [] - for token in self.basic_tokenizer.tokenize(text): - for sub_token in self.wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) - return split_tokens + return split_tokens - def convert_tokens_to_ids(self, tokens): - return convert_by_vocab(self.vocab, tokens) + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) - def convert_ids_to_tokens(self, ids): - return convert_by_vocab(self.inv_vocab, ids) + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) class BasicTokenizer(object): - """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - - def __init__(self, do_lower_case=True): - """Constructs a BasicTokenizer. - - Args: - do_lower_case: Whether to lower case the input. - """ - self.do_lower_case = do_lower_case - - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = convert_to_unicode(text) - text = self._clean_text(text) - - # This was added on November 1st, 2018 for the multilingual and Chinese - # models. This is also applied to the English models now, but it doesn't - # matter since the English models were not trained on any Chinese data - # and generally don't have any Chinese data in them (there are Chinese - # characters in the vocabulary because Wikipedia does have some Chinese - # words in the English Wikipedia.). - text = self._tokenize_chinese_chars(text) - - orig_tokens = whitespace_tokenize(text) - split_tokens = [] - for token in orig_tokens: - if self.do_lower_case: - token = token.lower() - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. - if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): - continue - if _is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) # + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) # + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) class WordpieceTokenizer(object): - """Runs WordPiece tokenziation.""" - - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): - self.vocab = vocab - self.unk_token = unk_token - self.max_input_chars_per_word = max_input_chars_per_word - - def tokenize(self, text): - """Tokenizes a piece of text into its word pieces. - - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. - - For example: - input = "unaffable" - output = ["un", "##aff", "##able"] - - Args: - text: A single token or whitespace separated tokens. This should have - already been passed through `BasicTokenizer. - - Returns: - A list of wordpiece tokens. - """ - - text = convert_to_unicode(text) - - output_tokens = [] - for token in whitespace_tokenize(text): - chars = list(token) - if len(chars) > self.max_input_chars_per_word: - output_tokens.append(self.unk_token) - continue - - is_bad = False - start = 0 - sub_tokens = [] - while start < len(chars): - end = len(chars) - cur_substr = None - while start < end: - substr = "".join(chars[start:end]) - if start > 0: - substr = "##" + substr - if substr in self.vocab: - cur_substr = substr - break - end -= 1 - if cur_substr is None: - is_bad = True - break - sub_tokens.append(cur_substr) - start = end - - if is_bad: - output_tokens.append(self.unk_token) - else: - output_tokens.extend(sub_tokens) - return output_tokens + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ( + (cp >= 33 and cp <= 47) + or (cp >= 58 and cp <= 64) + or (cp >= 91 and cp <= 96) + or (cp >= 123 and cp <= 126) + ): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization_test.py b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization_test.py index 0afaedd2..1b8bee0c 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization_test.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/bert_v2/tokenization_test.py @@ -25,113 +25,147 @@ class TokenizationTest(tf.test.TestCase): - def test_full_tokenizer(self): - vocab_tokens = [ - "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", - "##ing", "," - ] - with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: - if six.PY2: - vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - else: - vocab_writer.write("".join( - [x + "\n" for x in vocab_tokens]).encode("utf-8")) - - vocab_file = vocab_writer.name - - tokenizer = tokenization.FullTokenizer(vocab_file) - os.unlink(vocab_file) - - tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") - self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) - - self.assertAllEqual( - tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) - - def test_chinese(self): - tokenizer = tokenization.BasicTokenizer() - - self.assertAllEqual( - tokenizer.tokenize(u"ah\u535A\u63A8zz"), - [u"ah", u"\u535A", u"\u63A8", u"zz"]) - - def test_basic_tokenizer_lower(self): - tokenizer = tokenization.BasicTokenizer(do_lower_case=True) - - self.assertAllEqual( - tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), - ["hello", "!", "how", "are", "you", "?"]) - self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) - - def test_basic_tokenizer_no_lower(self): - tokenizer = tokenization.BasicTokenizer(do_lower_case=False) - - self.assertAllEqual( - tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), - ["HeLLo", "!", "how", "Are", "yoU", "?"]) - - def test_wordpiece_tokenizer(self): - vocab_tokens = [ - "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", - "##ing" - ] - - vocab = {} - for (i, token) in enumerate(vocab_tokens): - vocab[token] = i - tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) - - self.assertAllEqual(tokenizer.tokenize(""), []) - - self.assertAllEqual( - tokenizer.tokenize("unwanted running"), - ["un", "##want", "##ed", "runn", "##ing"]) - - self.assertAllEqual( - tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) - - def test_convert_tokens_to_ids(self): - vocab_tokens = [ - "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", - "##ing" - ] - - vocab = {} - for (i, token) in enumerate(vocab_tokens): - vocab[token] = i - - self.assertAllEqual( - tokenization.convert_tokens_to_ids( - vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) - - def test_is_whitespace(self): - self.assertTrue(tokenization._is_whitespace(u" ")) - self.assertTrue(tokenization._is_whitespace(u"\t")) - self.assertTrue(tokenization._is_whitespace(u"\r")) - self.assertTrue(tokenization._is_whitespace(u"\n")) - self.assertTrue(tokenization._is_whitespace(u"\u00A0")) - - self.assertFalse(tokenization._is_whitespace(u"A")) - self.assertFalse(tokenization._is_whitespace(u"-")) - - def test_is_control(self): - self.assertTrue(tokenization._is_control(u"\u0005")) - - self.assertFalse(tokenization._is_control(u"A")) - self.assertFalse(tokenization._is_control(u" ")) - self.assertFalse(tokenization._is_control(u"\t")) - self.assertFalse(tokenization._is_control(u"\r")) - self.assertFalse(tokenization._is_control(u"\U0001F4A9")) - - def test_is_punctuation(self): - self.assertTrue(tokenization._is_punctuation(u"-")) - self.assertTrue(tokenization._is_punctuation(u"$")) - self.assertTrue(tokenization._is_punctuation(u"`")) - self.assertTrue(tokenization._is_punctuation(u".")) - - self.assertFalse(tokenization._is_punctuation(u"A")) - self.assertFalse(tokenization._is_punctuation(u" ")) + def test_full_tokenizer(self): + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + ] + with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: + if six.PY2: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + else: + vocab_writer.write( + "".join([x + "\n" for x in vocab_tokens]).encode("utf-8") + ) + + vocab_file = vocab_writer.name + + tokenizer = tokenization.FullTokenizer(vocab_file) + os.unlink(vocab_file) + + tokens = tokenizer.tokenize("UNwant\u00e9d,running") + self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + + self.assertAllEqual( + tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9] + ) + + def test_chinese(self): + tokenizer = tokenization.BasicTokenizer() + + self.assertAllEqual( + tokenizer.tokenize("ah\u535a\u63a8zz"), ["ah", "\u535a", "\u63a8", "zz"] + ) + + def test_basic_tokenizer_lower(self): + tokenizer = tokenization.BasicTokenizer(do_lower_case=True) + + self.assertAllEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), + ["hello", "!", "how", "are", "you", "?"], + ) + self.assertAllEqual(tokenizer.tokenize("H\u00e9llo"), ["hello"]) + + def test_basic_tokenizer_no_lower(self): + tokenizer = tokenization.BasicTokenizer(do_lower_case=False) + + self.assertAllEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), + ["HeLLo", "!", "how", "Are", "yoU", "?"], + ) + + def test_wordpiece_tokenizer(self): + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ] + + vocab = {} + for i, token in enumerate(vocab_tokens): + vocab[token] = i + tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) + + self.assertAllEqual(tokenizer.tokenize(""), []) + + self.assertAllEqual( + tokenizer.tokenize("unwanted running"), + ["un", "##want", "##ed", "runn", "##ing"], + ) + + self.assertAllEqual( + tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"] + ) + + def test_convert_tokens_to_ids(self): + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ] + + vocab = {} + for i, token in enumerate(vocab_tokens): + vocab[token] = i + + self.assertAllEqual( + tokenization.convert_tokens_to_ids( + vocab, ["un", "##want", "##ed", "runn", "##ing"] + ), + [7, 4, 5, 8, 9], + ) + + def test_is_whitespace(self): + self.assertTrue(tokenization._is_whitespace(" ")) + self.assertTrue(tokenization._is_whitespace("\t")) + self.assertTrue(tokenization._is_whitespace("\r")) + self.assertTrue(tokenization._is_whitespace("\n")) + self.assertTrue(tokenization._is_whitespace("\u00a0")) + + self.assertFalse(tokenization._is_whitespace("A")) + self.assertFalse(tokenization._is_whitespace("-")) + + def test_is_control(self): + self.assertTrue(tokenization._is_control("\u0005")) + + self.assertFalse(tokenization._is_control("A")) + self.assertFalse(tokenization._is_control(" ")) + self.assertFalse(tokenization._is_control("\t")) + self.assertFalse(tokenization._is_control("\r")) + self.assertFalse(tokenization._is_control("\U0001f4a9")) + + def test_is_punctuation(self): + self.assertTrue(tokenization._is_punctuation("-")) + self.assertTrue(tokenization._is_punctuation("$")) + self.assertTrue(tokenization._is_punctuation("`")) + self.assertTrue(tokenization._is_punctuation(".")) + + self.assertFalse(tokenization._is_punctuation("A")) + self.assertFalse(tokenization._is_punctuation(" ")) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/compile.sh b/sample_programs/ml_sample_programs/nlp_models/bert/compile.sh index 84808dc3..2aa03488 100755 --- a/sample_programs/ml_sample_programs/nlp_models/bert/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/bert/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp bertsquad-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp bertsquad-12.onnx mlir-translate -mlir-to-llvmir bertsquad-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/createJsonOut.py b/sample_programs/ml_sample_programs/nlp_models/bert/createJsonOut.py index 966d2c1a..60fd853e 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/createJsonOut.py +++ b/sample_programs/ml_sample_programs/nlp_models/bert/createJsonOut.py @@ -8,13 +8,12 @@ import shutil import collections - ROOT = os.getcwd() -SQUAD_DIR = os.path.join(ROOT, 'squad-1.1') -OUT = os.path.join(ROOT, 'out') -BERT_BASE_DIR = os.path.join(ROOT, 'uncased_L-12_H-768_A-12') -LLFI_OUT = os.path.join(ROOT, 'llfi') -PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') +SQUAD_DIR = os.path.join(ROOT, "squad-1.1") +OUT = os.path.join(ROOT, "out") +BERT_BASE_DIR = os.path.join(ROOT, "uncased_L-12_H-768_A-12") +LLFI_OUT = os.path.join(ROOT, "llfi") +PROG_OUT = os.path.join(LLFI_OUT, "prog_output") # define some constants used by the model EVAL_BATCH_SIZE = 8 @@ -23,23 +22,29 @@ MAX_ANSWER_LENGTH = 30 MAX_QUERY_LENGTH = 64 DOC_STRIDE = 128 -VOCAB_FILE = os.path.join(BERT_BASE_DIR, 'vocab.txt') +VOCAB_FILE = os.path.join(BERT_BASE_DIR, "vocab.txt") + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def append_feature(eval_features, eval_writer, feature): eval_features.append(feature) eval_writer.process_feature(feature) + def main(inpSample): tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE, do_lower_case=True) - eval_examples = run_squad.read_squad_examples(input_file=os.path.join(SQUAD_DIR, "dev-v1.1.json"), is_training=False) - eval_writer = run_squad.FeatureWriter(filename=os.path.join(OUT, "eval.tf_record"), is_training=False) + eval_examples = run_squad.read_squad_examples( + input_file=os.path.join(SQUAD_DIR, "dev-v1.1.json"), is_training=False + ) + eval_writer = run_squad.FeatureWriter( + filename=os.path.join(OUT, "eval.tf_record"), is_training=False + ) eval_features = [] - run_squad.convert_examples_to_features( examples=eval_examples, tokenizer=tokenizer, @@ -47,7 +52,8 @@ def main(inpSample): doc_stride=DOC_STRIDE, max_query_length=MAX_QUERY_LENGTH, is_training=False, - output_fn=partial(append_feature, eval_features, eval_writer)) + output_fn=partial(append_feature, eval_features, eval_writer), + ) eval_writer.close() # Read LLTFI output from llfi/prog_output and add it to 'listResArr' @@ -58,26 +64,28 @@ def main(inpSample): listResArr = [] for filename in txtfiles: - resforSingleInput = [] - with open(filename, "r") as read_file: - resultJson = json.load(read_file) + resforSingleInput = [] + with open(filename, "r") as read_file: + resultJson = json.load(read_file) - for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) - listResArr.append(resforSingleInput) + for key, value in resultJson.items(): + resforSingleInput.append(value["Data"]) + listResArr.append(resforSingleInput) - RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) - pathOutput = os.path.join(OUT, 'onnx-pred') - pathnBestOutput = os.path.join(OUT, 'onnx-nbest-pred') + RawResult = collections.namedtuple( + "RawResult", ["unique_id", "start_logits", "end_logits"] + ) + pathOutput = os.path.join(OUT, "onnx-pred") + pathnBestOutput = os.path.join(OUT, "onnx-nbest-pred") # Reference outputs refOutFile = f"ref_outputs_json/pred_inp_{str(inpSample)}/onnx_predictions.json" with open(refOutFile, "r") as read_file: - refOut = json.load(read_file) + refOut = json.load(read_file) - #refOutNbestFile = f"ref_outputs_json/pred_inp_{str(inpSample)}/onnx_nbest_predictions.json" - #with open(refOutNbestFile, "r") as read_file: - #refOutNbest = json.load(read_file) + # refOutNbestFile = f"ref_outputs_json/pred_inp_{str(inpSample)}/onnx_nbest_predictions.json" + # with open(refOutNbestFile, "r") as read_file: + # refOutNbest = json.load(read_file) # field 'unique_id' if inpSample == 8: @@ -89,45 +97,63 @@ def main(inpSample): # Get output from 'listResArr' array and convert to text for index in range(len(listResArr)): - all_results = [] - all_results.append(RawResult(unique_id=uniqueId, start_logits=listResArr[index][1], end_logits=listResArr[index][0])) - - out_predictions_json = f"onnx_predictions{index}.json" - out_nbestpredictions_json = f"onnx_nbest_predictions{index}.json" - out_null_odds_json = f"onnx_null_odds{index}.json" - run_squad.write_predictions(eval_examples[:1], eval_features[:1], all_results, - N_BEST_SIZE, MAX_ANSWER_LENGTH, True, - os.path.join(OUT, out_predictions_json), - os.path.join(OUT, out_nbestpredictions_json), - os.path.join(OUT, out_null_odds_json)) - - # Compare the generated outputs(final prediction and N best predictions) with reference outputs - with open(os.path.join(OUT, out_predictions_json), "r") as read_file: - out = json.load(read_file) - - if out != refOut: - print(f"FI in run number {index} led to incorrect output") - else: - print(f"FI in run number {index} led to correct output") - - #with open(os.path.join(OUT, out_nbestpredictions_json), "r") as read_file: - #outNbest = json.load(read_file) - - #if outNbest != refOutNbest: - #print(f"FI in run number {index} led to incorrect output(N best)") - - - # Save the text output - if not os.path.isdir(pathOutput): - os.mkdir(pathOutput) - if not os.path.isdir(pathnBestOutput): - os.mkdir(pathnBestOutput) - - ONNX_PRED = os.path.join(OUT, 'onnx-pred') - ONNX_NBEST_PRED = os.path.join(OUT, 'onnx-nbest-pred') + all_results = [] + all_results.append( + RawResult( + unique_id=uniqueId, + start_logits=listResArr[index][1], + end_logits=listResArr[index][0], + ) + ) + + out_predictions_json = f"onnx_predictions{index}.json" + out_nbestpredictions_json = f"onnx_nbest_predictions{index}.json" + out_null_odds_json = f"onnx_null_odds{index}.json" + run_squad.write_predictions( + eval_examples[:1], + eval_features[:1], + all_results, + N_BEST_SIZE, + MAX_ANSWER_LENGTH, + True, + os.path.join(OUT, out_predictions_json), + os.path.join(OUT, out_nbestpredictions_json), + os.path.join(OUT, out_null_odds_json), + ) + + # Compare the generated outputs(final prediction and N best predictions) with reference outputs + with open(os.path.join(OUT, out_predictions_json), "r") as read_file: + out = json.load(read_file) + + if out != refOut: + print(f"FI in run number {index} led to incorrect output") + else: + print(f"FI in run number {index} led to correct output") + + # with open(os.path.join(OUT, out_nbestpredictions_json), "r") as read_file: + # outNbest = json.load(read_file) + + # if outNbest != refOutNbest: + # print(f"FI in run number {index} led to incorrect output(N best)") + + # Save the text output + if not os.path.isdir(pathOutput): + os.mkdir(pathOutput) + if not os.path.isdir(pathnBestOutput): + os.mkdir(pathnBestOutput) + + ONNX_PRED = os.path.join(OUT, "onnx-pred") + ONNX_NBEST_PRED = os.path.join(OUT, "onnx-nbest-pred") + + shutil.move( + os.path.join(OUT, out_predictions_json), + os.path.join(ONNX_PRED, out_predictions_json), + ) + shutil.move( + os.path.join(OUT, out_nbestpredictions_json), + os.path.join(ONNX_NBEST_PRED, out_nbestpredictions_json), + ) - shutil.move(os.path.join(OUT, out_predictions_json), os.path.join(ONNX_PRED, out_predictions_json)) - shutil.move(os.path.join(OUT, out_nbestpredictions_json), os.path.join(ONNX_NBEST_PRED, out_nbestpredictions_json)) if __name__ == "__main__": inpSample = int(sys.argv[1]) diff --git a/sample_programs/ml_sample_programs/nlp_models/bert/input.c b/sample_programs/ml_sample_programs/nlp_models/bert/input.c index ab3dc159..93902266 100644 --- a/sample_programs/ml_sample_programs/nlp_models/bert/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/bert/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/compile.sh b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/compile.sh index 403cb5a5..356033f8 100755 --- a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/compile.sh @@ -11,7 +11,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp gpt2-lm-head-10.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp gpt2-lm-head-10.onnx mlir-translate -mlir-to-llvmir gpt2-lm-head-10.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/createPredFile.py b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/createPredFile.py index 440fff21..3e37f9bf 100644 --- a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/createPredFile.py +++ b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/createPredFile.py @@ -7,8 +7,8 @@ import tensorflow as tf ROOT = os.getcwd() -LLFI_OUT = os.path.join(ROOT, 'llfi') -PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') +LLFI_OUT = os.path.join(ROOT, "llfi") +PROG_OUT = os.path.join(LLFI_OUT, "prog_output") inputs = [] inputs.append("This chair is white and the table is") @@ -16,22 +16,38 @@ inputs.append("I am a doctor and I work at a") inputs.append("I like playing with my") inputs.append("A rose by any other name would smell as") -inputs.append("US-led coalition air strikes on a jail run by the Islamic State group in eastern Syria killed") -inputs.append("A magazine supplement with an image of Adolf Hitler and the title 'The Unreadable Book' is pictured in") -inputs.append("Winter isn't done with us yet. Ottawa can expect another 10 to 15 centimetres of") -inputs.append("Refined mansion tax proposal being fed into debate on abolishing 50p tax rate for those earning more than") -inputs.append("Ghazala Khan, the mother of a fallen U.S. soldier of Muslim faith, is responding to Donald Trump’s") +inputs.append( + "US-led coalition air strikes on a jail run by the Islamic State group in eastern Syria killed" +) +inputs.append( + "A magazine supplement with an image of Adolf Hitler and the title 'The Unreadable Book' is pictured in" +) +inputs.append( + "Winter isn't done with us yet. Ottawa can expect another 10 to 15 centimetres of" +) +inputs.append( + "Refined mansion tax proposal being fed into debate on abolishing 50p tax rate for those earning more than" +) +inputs.append( + "Ghazala Khan, the mother of a fallen U.S. soldier of Muslim faith, is responding to Donald Trump’s" +) -tokenizer = GPT2Tokenizer.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(inpSample): # Get LLTFI outputs in listResArr listResArr = [] - list_of_files = sorted( filter( lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), - os.listdir(PROG_OUT) ), key=lltfi_sort ) + list_of_files = sorted( + filter( + lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), os.listdir(PROG_OUT) + ), + key=lltfi_sort, + ) for i in range(len(list_of_files)): list_of_files[i] = os.path.join(PROG_OUT, list_of_files[i]) @@ -42,14 +58,14 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) + resforSingleInput.append(value["Data"]) listResArr.append(resforSingleInput) list_output_np = [] # Reshape the output and store as numpy array for elem in listResArr: output_np = np.asarray(elem[0]) - output_np = output_np.reshape(1,1,-1,50257) # Shape (1,1,len,50257) + output_np = output_np.reshape(1, 1, -1, 50257) # Shape (1,1,len,50257) list_output_np.append(output_np) tokens = np.array(tokenizer.encode(inputs[inpSample])) @@ -58,18 +74,19 @@ def main(inpSample): listPreds = [] for elemIndex in range(len(list_output_np)): input_to_model = tf.convert_to_tensor( - [[tokenizer.encode(inputs[inpSample], add_special_tokens=True)]]) # Shape [1, 1, len] - prev = input_to_model # [1, 1, len] Set prev as input in the first step - prev = prev[0] # [1, len] + [[tokenizer.encode(inputs[inpSample], add_special_tokens=True)]] + ) # Shape [1, 1, len] + prev = input_to_model # [1, 1, len] Set prev as input in the first step + prev = prev[0] # [1, len] output = prev logits = list_output_np[elemIndex][0] logits = logits[:, -1, :] logits = tf.convert_to_tensor(logits) - log_probs = tf.nn.softmax(logits,axis=-1) + log_probs = tf.nn.softmax(logits, axis=-1) prob, prev = tf.math.top_k(log_probs, k=10) output = tf.concat((output, prev), axis=1) - final_output = output[:, len(tokens):].numpy().tolist() + final_output = output[:, len(tokens) :].numpy().tolist() listPreds.append(f"Run #{elemIndex} Predictions and Probability:\n") for opNum in range(0, len(final_output[0])): @@ -77,7 +94,7 @@ def main(inpSample): opProb = prob.numpy()[0][opNum] listPreds.append(f"{text} : {opProb}\n") - myfile = open('prediction/PredResult.txt', 'w') + myfile = open("prediction/PredResult.txt", "w") myfile.writelines(listPreds) myfile.close() diff --git a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/helper_scripts/exportInputsAsTensors.py b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/helper_scripts/exportInputsAsTensors.py index d72914a0..00529f77 100644 --- a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/helper_scripts/exportInputsAsTensors.py +++ b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/helper_scripts/exportInputsAsTensors.py @@ -9,17 +9,27 @@ inputs.append("I am a doctor and I work at a") inputs.append("I like playing with my") inputs.append("A rose by any other name would smell as") -inputs.append("US-led coalition air strikes on a jail run by the Islamic State group in eastern Syria killed") -inputs.append("A magazine supplement with an image of Adolf Hitler and the title 'The Unreadable Book' is pictured in") -inputs.append("Winter isn't done with us yet. Ottawa can expect another 10 to 15 centimetres of") -inputs.append("Refined mansion tax proposal being fed into debate on abolishing 50p tax rate for those earning more than") -inputs.append("Ghazala Khan, the mother of a fallen U.S. soldier of Muslim faith, is responding to Donald Trump’s") +inputs.append( + "US-led coalition air strikes on a jail run by the Islamic State group in eastern Syria killed" +) +inputs.append( + "A magazine supplement with an image of Adolf Hitler and the title 'The Unreadable Book' is pictured in" +) +inputs.append( + "Winter isn't done with us yet. Ottawa can expect another 10 to 15 centimetres of" +) +inputs.append( + "Refined mansion tax proposal being fed into debate on abolishing 50p tax rate for those earning more than" +) +inputs.append( + "Ghazala Khan, the mother of a fallen U.S. soldier of Muslim faith, is responding to Donald Trump’s" +) -tokenizer = GPT2Tokenizer.from_pretrained('gpt2') +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") for index in range(0, len(inputs)): - tokens = np.array(tokenizer.encode(inputs[index])) # Shape : (len,) - input_arr = tokens.reshape(1,1,-1) # Shape : (1,1,len) - input_tensor = numpy_helper.from_array(input_arr) - with open("input_{}.pb".format(index), 'wb') as file: - file.write(input_tensor.SerializeToString()) + tokens = np.array(tokenizer.encode(inputs[index])) # Shape : (len,) + input_arr = tokens.reshape(1, 1, -1) # Shape : (1,1,len) + input_tensor = numpy_helper.from_array(input_arr) + with open("input_{}.pb".format(index), "wb") as file: + file.write(input_tensor.SerializeToString()) diff --git a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/input.c b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/input.c index 9f4dba48..5525e637 100644 --- a/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/gpt2-lm-head-10/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/gpt2/compile.sh b/sample_programs/ml_sample_programs/nlp_models/gpt2/compile.sh index b9c214ef..4010af19 100755 --- a/sample_programs/ml_sample_programs/nlp_models/gpt2/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/gpt2/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp gpt2-10.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp gpt2-10.onnx mlir-translate -mlir-to-llvmir gpt2-10.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/gpt2/input.c b/sample_programs/ml_sample_programs/nlp_models/gpt2/input.c index 9f4dba48..5525e637 100644 --- a/sample_programs/ml_sample_programs/nlp_models/gpt2/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/gpt2/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/compile.sh b/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/compile.sh index 1572fe36..4caabfbd 100755 --- a/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp roberta-base-11.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp roberta-base-11.onnx mlir-translate -mlir-to-llvmir roberta-base-11.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/input.c b/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/input.c index 39d5aa20..fb3f6287 100644 --- a/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/roberta-base-11/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/compile.sh b/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/compile.sh index afabaaf7..97e360e3 100755 --- a/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp roberta-sequence-classification-9.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp roberta-sequence-classification-9.onnx mlir-translate -mlir-to-llvmir roberta-sequence-classification-9.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/createPredFile.py b/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/createPredFile.py index 35afa540..66b51ed8 100644 --- a/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/createPredFile.py +++ b/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/createPredFile.py @@ -4,41 +4,47 @@ import json ROOT = os.getcwd() -LLFI_OUT = os.path.join(ROOT, 'llfi') -PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') +LLFI_OUT = os.path.join(ROOT, "llfi") +PROG_OUT = os.path.join(LLFI_OUT, "prog_output") + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) def main(): - # Get LLTFI outputs in listResArr - listResArr = [] - list_of_files = sorted( filter( lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), - os.listdir(PROG_OUT) ), key=lltfi_sort ) - - for i in range(len(list_of_files)): - list_of_files[i] = os.path.join(PROG_OUT, list_of_files[i]) - - for filename in list_of_files: - with open(filename, "r") as read_file: - resultJson = json.load(read_file) - - for key, value in resultJson.items(): - listResArr.append(value['Data']) - - listPreds = [] - # Get Sentiment (positive/negative) prediction - for resIndex in range(len(listResArr)): - elem = listResArr[resIndex] - pred = np.argmax(elem) - if(pred == 0): - listPreds.append(f"Run #{resIndex} Prediction: Negative\n") - elif(pred == 1): - listPreds.append(f"Run #{resIndex} Prediction: Positive\n") - - myfile = open('prediction/PredResult.txt', 'w') - myfile.writelines(listPreds) - myfile.close() + # Get LLTFI outputs in listResArr + listResArr = [] + list_of_files = sorted( + filter( + lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), os.listdir(PROG_OUT) + ), + key=lltfi_sort, + ) + + for i in range(len(list_of_files)): + list_of_files[i] = os.path.join(PROG_OUT, list_of_files[i]) + + for filename in list_of_files: + with open(filename, "r") as read_file: + resultJson = json.load(read_file) + + for key, value in resultJson.items(): + listResArr.append(value["Data"]) + + listPreds = [] + # Get Sentiment (positive/negative) prediction + for resIndex in range(len(listResArr)): + elem = listResArr[resIndex] + pred = np.argmax(elem) + if pred == 0: + listPreds.append(f"Run #{resIndex} Prediction: Negative\n") + elif pred == 1: + listPreds.append(f"Run #{resIndex} Prediction: Positive\n") + + myfile = open("prediction/PredResult.txt", "w") + myfile.writelines(listPreds) + myfile.close() if __name__ == "__main__": diff --git a/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/input.c b/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/input.c index 39d5aa20..fb3f6287 100644 --- a/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/roberta-seq-classification-9/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/compile.sh b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/compile.sh index 498b8760..24e20de1 100755 --- a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/compile.sh @@ -17,7 +17,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp t5-decoder-with-lm-head-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp t5-decoder-with-lm-head-12.onnx mlir-translate -mlir-to-llvmir t5-decoder-with-lm-head-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/createPredFile.py b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/createPredFile.py index 372e284d..478033ef 100644 --- a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/createPredFile.py +++ b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/createPredFile.py @@ -14,19 +14,35 @@ import sys, pdb prompts = [] -prompts.append('translate English to German: I declare resumed the session of the European Parliament.') -prompts.append('translate English to German: Statements by the President') -prompts.append('translate English to French: Statements by the President') -prompts.append('translate English to French: I declare resumed the session of the European Parliament.') -prompts.append('translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake') -prompts.append('translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake') -prompts.append('''translate English to German: (The House rose and observed a minute's silence)''') -prompts.append('''translate English to French: (The House rose and observed a minute's silence)''') -prompts.append('translate English to German: I should like, on behalf of the European Parliament, to express') -prompts.append('translate English to French: I should like, on behalf of the European Parliament, to express') +prompts.append( + "translate English to German: I declare resumed the session of the European Parliament." +) +prompts.append("translate English to German: Statements by the President") +prompts.append("translate English to French: Statements by the President") +prompts.append( + "translate English to French: I declare resumed the session of the European Parliament." +) +prompts.append( + "translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake" +) +prompts.append( + "translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake" +) +prompts.append( + """translate English to German: (The House rose and observed a minute's silence)""" +) +prompts.append( + """translate English to French: (The House rose and observed a minute's silence)""" +) +prompts.append( + "translate English to German: I should like, on behalf of the European Parliament, to express" +) +prompts.append( + "translate English to French: I should like, on behalf of the European Parliament, to express" +) prev_dec_out = [] -prev_dec_out.append(torch.tensor([[0, 1674]])) #Ich +prev_dec_out.append(torch.tensor([[0, 1674]])) # Ich prev_dec_out.append(torch.tensor([[0, 28019]])) prev_dec_out.append(torch.tensor([[0, 4829]])) prev_dec_out.append(torch.tensor([[0, 1022]])) @@ -50,27 +66,37 @@ prev_out.append("Au") ROOT = os.getcwd() -LLFI_OUT = os.path.join(ROOT, 'llfi') -PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') +LLFI_OUT = os.path.join(ROOT, "llfi") +PROG_OUT = os.path.join(LLFI_OUT, "prog_output") -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) # Ref: https://github.com/abelriboulot/onnxt5 ### Helper class to get decoder output class GenerativeT5_decoder(torch.nn.Module): - """ Code ref: https://github.com/abelriboulot/onnxt5 - Args: - encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - decoder_with_lm_head: decoder with language model head on top. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - tokenizer: huggingface tokenizer - onnx (bool): whether to use onnx or the default pytorch - cuda (bool): whether to use cuda or the cpu + """Code ref: https://github.com/abelriboulot/onnxt5 + Args: + encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + decoder_with_lm_head: decoder with language model head on top. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + tokenizer: huggingface tokenizer + onnx (bool): whether to use onnx or the default pytorch + cuda (bool): whether to use cuda or the cpu """ - def __init__(self, encoder_hidden_state, dec_outputs, decoder, tokenizer, onnx=False, cuda=False): + + def __init__( + self, + encoder_hidden_state, + dec_outputs, + decoder, + tokenizer, + onnx=False, + cuda=False, + ): super().__init__() self.encoder_hidden_state = encoder_hidden_state self.dec_outputs = dec_outputs @@ -79,12 +105,21 @@ def __init__(self, encoder_hidden_state, dec_outputs, decoder, tokenizer, onnx=F self.cuda = cuda self.decoder = decoder - def forward(self, max_length, prev_decOut_tensor, temperature=1., repetition_penalty=1., top_k=50, top_p=0, max_context_length=512): - """ Forward function to generate text after a prompt - Args: - prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" - or "translate English to German:" - max_context_length: maximum number of tokens to use as context + def forward( + self, + max_length, + prev_decOut_tensor, + temperature=1.0, + repetition_penalty=1.0, + top_k=50, + top_p=0, + max_context_length=512, + ): + """Forward function to generate text after a prompt + Args: + prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" + or "translate English to German:" + max_context_length: maximum number of tokens to use as context """ with torch.no_grad(): new_tokens = torch.tensor(()) @@ -100,7 +135,9 @@ def forward(self, max_length, prev_decOut_tensor, temperature=1., repetition_pen # Run first decoder loop with LLTFI's output outputs = self.dec_outputs[0] - next_token_logits = outputs[-1, :] / (temperature if temperature > 0 else 1.0) + next_token_logits = outputs[-1, :] / ( + temperature if temperature > 0 else 1.0 + ) new_logits.append(next_token_logits) for _ in set(generated.view(-1).tolist()): @@ -109,18 +146,31 @@ def forward(self, max_length, prev_decOut_tensor, temperature=1., repetition_pen if temperature == 0: # greedy sampling: next_token = torch.argmax(next_token_logits).unsqueeze(0) else: - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) new_tokens = torch.cat((new_tokens, next_token), 0) # Run all subsequent decodr loops using T5-decoder model. for _ in trange(max_length - 1): - outputs = torch.tensor(self.decoder.run(None, {"input_ids": generated.cpu().numpy(), - "encoder_hidden_states": self.encoder_hidden_state})) + outputs = torch.tensor( + self.decoder.run( + None, + { + "input_ids": generated.cpu().numpy(), + "encoder_hidden_states": self.encoder_hidden_state, + }, + ) + ) outputs = outputs[0][0] - next_token_logits = outputs[-1, :] / (temperature if temperature > 0 else 1.0) + next_token_logits = outputs[-1, :] / ( + temperature if temperature > 0 else 1.0 + ) if int(next_token_logits.argmax()) == 1: break new_logits.append(next_token_logits) @@ -129,8 +179,12 @@ def forward(self, max_length, prev_decOut_tensor, temperature=1., repetition_pen if temperature == 0: # greedy sampling: next_token = torch.argmax(next_token_logits).unsqueeze(0) else: - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) new_tokens = torch.cat((new_tokens, next_token), 0) @@ -139,21 +193,29 @@ def forward(self, max_length, prev_decOut_tensor, temperature=1., repetition_pen def main(inpSample): - encoder_sess = InferenceSession('t5-encoder-12.onnx') - decoder_sess = InferenceSession('t5-decoder-with-lm-head-12.onnx') + encoder_sess = InferenceSession("t5-encoder-12.onnx") + decoder_sess = InferenceSession("t5-decoder-with-lm-head-12.onnx") max_context_length = 2048 _, _, tokenizer = get_encoder_decoder_tokenizer() with torch.no_grad(): - generated = torch.tensor(tokenizer(prompts[inpSample])['input_ids'])[:max_context_length - 1].unsqueeze(0) - encoder_hidden_state = encoder_sess.run(None, {"input_ids": generated.cpu().numpy()})[0] + generated = torch.tensor(tokenizer(prompts[inpSample])["input_ids"])[ + : max_context_length - 1 + ].unsqueeze(0) + encoder_hidden_state = encoder_sess.run( + None, {"input_ids": generated.cpu().numpy()} + )[0] # Convert lltfi output to text listResArr = [] - list_of_files = sorted( filter( lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), - os.listdir(PROG_OUT) ), key=lltfi_sort) + list_of_files = sorted( + filter( + lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), os.listdir(PROG_OUT) + ), + key=lltfi_sort, + ) for i in range(len(list_of_files)): - list_of_files[i] = os.path.join(PROG_OUT, list_of_files[i]) + list_of_files[i] = os.path.join(PROG_OUT, list_of_files[i]) for filename in list_of_files: resforSingleInput = [] @@ -161,31 +223,37 @@ def main(inpSample): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) + resforSingleInput.append(value["Data"]) listResArr.append(resforSingleInput) list_output_np = [] # Reshape the output and store as numpy array for elem in listResArr: - output_dec_np = np.asarray(elem[0], dtype=np.float32) - output_dec_np = output_dec_np.reshape(1,-1,32128) - output_dec_tensor = torch.from_numpy(output_dec_np) - list_output_np.append(output_dec_tensor) + output_dec_np = np.asarray(elem[0], dtype=np.float32) + output_dec_np = output_dec_np.reshape(1, -1, 32128) + output_dec_tensor = torch.from_numpy(output_dec_np) + list_output_np.append(output_dec_tensor) # Get predictions final_out_list = [] for elemIndex in range(len(list_output_np)): _, _, tokenizer = get_encoder_decoder_tokenizer() - generative_t5 = GenerativeT5_decoder(encoder_hidden_state, list_output_np[elemIndex], decoder_sess, tokenizer, onnx=True) - tokens, logits = generative_t5(20, prev_dec_out[inpSample], temperature=0.) + generative_t5 = GenerativeT5_decoder( + encoder_hidden_state, + list_output_np[elemIndex], + decoder_sess, + tokenizer, + onnx=True, + ) + tokens, logits = generative_t5(20, prev_dec_out[inpSample], temperature=0.0) final_out = prev_out[inpSample] + " " + tokens final_out_list.append(f"Run #{elemIndex} Prediction:{final_out}\n") - myfile = open('prediction/PredResult.txt', 'w') + myfile = open("prediction/PredResult.txt", "w") myfile.writelines(final_out_list) myfile.close() + if __name__ == "__main__": inpSample = int(sys.argv[1]) main(inpSample) - diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/helper_scripts/generate_inputs.py b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/helper_scripts/generate_inputs.py index ca568777..afbd513b 100644 --- a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/helper_scripts/generate_inputs.py +++ b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/helper_scripts/generate_inputs.py @@ -12,51 +12,71 @@ # WMT19 Dataset from huggingface prompts = [] -prompts.append('translate English to German: I declare resumed the session of the European Parliament.') -prompts.append('translate English to German: Statements by the President') -prompts.append('translate English to French: Statements by the President') -prompts.append('translate English to French: I declare resumed the session of the European Parliament.') -prompts.append('translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake') -prompts.append('translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake') -prompts.append('''translate English to German: (The House rose and observed a minute's silence)''') -prompts.append('''translate English to French: (The House rose and observed a minute's silence)''') -prompts.append('translate English to German: I should like, on behalf of the European Parliament, to express') -prompts.append('translate English to French: I should like, on behalf of the European Parliament, to express') +prompts.append( + "translate English to German: I declare resumed the session of the European Parliament." +) +prompts.append("translate English to German: Statements by the President") +prompts.append("translate English to French: Statements by the President") +prompts.append( + "translate English to French: I declare resumed the session of the European Parliament." +) +prompts.append( + "translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake" +) +prompts.append( + "translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake" +) +prompts.append( + """translate English to German: (The House rose and observed a minute's silence)""" +) +prompts.append( + """translate English to French: (The House rose and observed a minute's silence)""" +) +prompts.append( + "translate English to German: I should like, on behalf of the European Parliament, to express" +) +prompts.append( + "translate English to French: I should like, on behalf of the European Parliament, to express" +) + class GenerativeT5_pytorch(torch.nn.Module): - """ This wrapper utility function implements a single beam search to generate efficiently text. - A lot of the credit goes to the huggingface team and its chief scientist Thomas Wolf whose implementation I based - myself off. - Args: - encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - decoder_with_lm_head: decoder with language model head on top. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - tokenizer: huggingface tokenizer - onnx (bool): whether to use onnx or the default pytorch - cuda (bool): whether to use cuda or the cpu - Examples: - For pytorch: - >>> from transformers import T5Tokenizer - >>> from onnxt5 import create_t5_encoder_decoder, GenerativeT5 - >>> pretrained_model = 't5-base' # This can be a pretrained version, or the path to a huggingface model - >>> simplified_encoder, decoder_with_lm_head = create_t5_encoder_decoder(pretrained_model) - >>> tokenizer = T5Tokenizer.from_pretrained(pretrained_model) - >>> generative_t5 = GenerativeT5(simplified_encoder, decoder_with_lm_head, tokenizer) - >>> generative_t5('translate English to French: I was a victim of a series of accidents.', 16, temperature=0.)[0] - >>> # Output: "Je suis victime d'une série d'accidents." - For onnx: - >>> from transformers import T5Tokenizer - >>> from onnxruntime import InferenceSession - >>> from onnxt5 import GenerativeT5 - >>> decoder_sess = InferenceSession('~/t5-decoder-with-lm-head.onnx') - >>> encoder_sess = InferenceSession('~/t5-encoder.onnx') - >>> tokenizer = T5Tokenizer.from_pretrained(pretrained_model) - >>> generative_t5 = GenerativeT5(encoder_sess, decoder_sess, tokenizer, onnx=True) - >>> generative_t5('translate English to French: I was a victim of a series of accidents.', 16, temperature=0.)[0] - >>> # Output: "Je suis victime d'une série d'accidents." + """This wrapper utility function implements a single beam search to generate efficiently text. + A lot of the credit goes to the huggingface team and its chief scientist Thomas Wolf whose implementation I based + myself off. + Args: + encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + decoder_with_lm_head: decoder with language model head on top. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + tokenizer: huggingface tokenizer + onnx (bool): whether to use onnx or the default pytorch + cuda (bool): whether to use cuda or the cpu + Examples: + For pytorch: + >>> from transformers import T5Tokenizer + >>> from onnxt5 import create_t5_encoder_decoder, GenerativeT5 + >>> pretrained_model = 't5-base' # This can be a pretrained version, or the path to a huggingface model + >>> simplified_encoder, decoder_with_lm_head = create_t5_encoder_decoder(pretrained_model) + >>> tokenizer = T5Tokenizer.from_pretrained(pretrained_model) + >>> generative_t5 = GenerativeT5(simplified_encoder, decoder_with_lm_head, tokenizer) + >>> generative_t5('translate English to French: I was a victim of a series of accidents.', 16, temperature=0.)[0] + >>> # Output: "Je suis victime d'une série d'accidents." + For onnx: + >>> from transformers import T5Tokenizer + >>> from onnxruntime import InferenceSession + >>> from onnxt5 import GenerativeT5 + >>> decoder_sess = InferenceSession('~/t5-decoder-with-lm-head.onnx') + >>> encoder_sess = InferenceSession('~/t5-encoder.onnx') + >>> tokenizer = T5Tokenizer.from_pretrained(pretrained_model) + >>> generative_t5 = GenerativeT5(encoder_sess, decoder_sess, tokenizer, onnx=True) + >>> generative_t5('translate English to French: I was a victim of a series of accidents.', 16, temperature=0.)[0] + >>> # Output: "Je suis victime d'une série d'accidents." """ - def __init__(self, encoder, decoder_with_lm_head, tokenizer, onnx=False, cuda=False): + + def __init__( + self, encoder, decoder_with_lm_head, tokenizer, onnx=False, cuda=False + ): super().__init__() self.encoder = encoder self.decoder_with_lm_head = decoder_with_lm_head @@ -64,23 +84,36 @@ def __init__(self, encoder, decoder_with_lm_head, tokenizer, onnx=False, cuda=Fa self.onnx = onnx self.cuda = cuda - def forward(self, prompt, max_length, temperature=1., repetition_penalty=1., top_k=50, top_p=0, max_context_length=512): - """ Forward function to generate text after a prompt - Args: - prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" - or "translate English to German:" - max_context_length: maximum number of tokens to use as context + def forward( + self, + prompt, + max_length, + temperature=1.0, + repetition_penalty=1.0, + top_k=50, + top_p=0, + max_context_length=512, + ): + """Forward function to generate text after a prompt + Args: + prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" + or "translate English to German:" + max_context_length: maximum number of tokens to use as context """ with torch.no_grad(): new_tokens = torch.tensor(()) new_logits = [] - generated = torch.tensor(self.tokenizer(prompt)['input_ids'])[:max_context_length - 1].unsqueeze(0) + generated = torch.tensor(self.tokenizer(prompt)["input_ids"])[ + : max_context_length - 1 + ].unsqueeze(0) if self.cuda and not self.onnx: generated = generated.cuda() temperature = temperature # Getting encoder past if self.onnx: - encoder_outputs_prompt = self.encoder.run(None, {"input_ids": generated.cpu().numpy()})[0] + encoder_outputs_prompt = self.encoder.run( + None, {"input_ids": generated.cpu().numpy()} + )[0] else: encoder_outputs_prompt = self.encoder(generated) print(encoder_outputs_prompt.shape) @@ -89,19 +122,30 @@ def forward(self, prompt, max_length, temperature=1., repetition_penalty=1., top top_p = top_p # The sequence now needs to start with a - generated = torch.zeros((1,1), dtype=torch.long) + generated = torch.zeros((1, 1), dtype=torch.long) if self.cuda and not self.onnx: generated = generated.cuda() for _ in range(max_length): if self.onnx: - outputs = torch.tensor(self.decoder_with_lm_head.run(None, {"input_ids": generated.cpu().numpy(), - "encoder_hidden_states": encoder_outputs_prompt})[0][0]) + outputs = torch.tensor( + self.decoder_with_lm_head.run( + None, + { + "input_ids": generated.cpu().numpy(), + "encoder_hidden_states": encoder_outputs_prompt, + }, + )[0][0] + ) else: - outputs = self.decoder_with_lm_head(input_ids=generated, - encoder_hidden_states=encoder_outputs_prompt)[0] + outputs = self.decoder_with_lm_head( + input_ids=generated, + encoder_hidden_states=encoder_outputs_prompt, + )[0] print(generated) - next_token_logits = outputs[-1, :] / (temperature if temperature > 0 else 1.0) + next_token_logits = outputs[-1, :] / ( + temperature if temperature > 0 else 1.0 + ) if int(next_token_logits.argmax()) == 1: break new_logits.append(next_token_logits) @@ -110,8 +154,12 @@ def forward(self, prompt, max_length, temperature=1., repetition_penalty=1., top if temperature == 0: # greedy sampling: next_token = torch.argmax(next_token_logits).unsqueeze(0) else: - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) new_tokens = torch.cat((new_tokens, next_token), 0) print(self.tokenizer.decode(new_tokens)) @@ -120,7 +168,7 @@ def forward(self, prompt, max_length, temperature=1., repetition_penalty=1., top prev_dec_out = [] -prev_dec_out.append(np.array([[0, 1674]])) #Ich +prev_dec_out.append(np.array([[0, 1674]])) # Ich prev_dec_out.append(np.array([[0, 28019]])) prev_dec_out.append(np.array([[0, 4829]])) prev_dec_out.append(np.array([[0, 1022]])) @@ -131,36 +179,45 @@ def forward(self, prompt, max_length, temperature=1., repetition_penalty=1., top prev_dec_out.append(np.array([[0, 1674]])) prev_dec_out.append(np.array([[0, 1957]])) + # Get encoder hidden state def createInp(inpSample): - global prompts, prev_dec_out + global prompts, prev_dec_out - encoder_sess = InferenceSession('t5-encoder-12.onnx') - max_context_length = 2048 - _, _, tokenizer = get_encoder_decoder_tokenizer() - with torch.no_grad(): - generated = torch.tensor(tokenizer(prompts[inpSample])['input_ids'])[:max_context_length - 1].unsqueeze(0) - encoder_hidden_state = encoder_sess.run(None, {"input_ids": generated.cpu().numpy()})[0] + encoder_sess = InferenceSession("t5-encoder-12.onnx") + max_context_length = 2048 + _, _, tokenizer = get_encoder_decoder_tokenizer() + with torch.no_grad(): + generated = torch.tensor(tokenizer(prompts[inpSample])["input_ids"])[ + : max_context_length - 1 + ].unsqueeze(0) + encoder_hidden_state = encoder_sess.run( + None, {"input_ids": generated.cpu().numpy()} + )[0] + + inp_1_tensor = numpy_helper.from_array(prev_dec_out[inpSample]) + with open("input{}_0.pb".format(inpSample), "wb") as file: + file.write(inp_1_tensor.SerializeToString()) - inp_1_tensor = numpy_helper.from_array(prev_dec_out[inpSample]) - with open("input{}_0.pb".format(inpSample), 'wb') as file: - file.write(inp_1_tensor.SerializeToString()) + encoder_hidden_state_tensor = numpy_helper.from_array(encoder_hidden_state) + with open("input{}_1.pb".format(inpSample), "wb") as file: + file.write(encoder_hidden_state_tensor.SerializeToString()) - encoder_hidden_state_tensor = numpy_helper.from_array(encoder_hidden_state) - with open("input{}_1.pb".format(inpSample), 'wb') as file: - file.write(encoder_hidden_state_tensor.SerializeToString()) def main(): - global prompts - for i in range(0, 10): - # Use encoder and decoder to generate output text - decoder_sess = InferenceSession('t5-decoder-with-lm-head-12.onnx') - encoder_sess = InferenceSession('t5-encoder-12.onnx') - _, _, tokenizer = get_encoder_decoder_tokenizer() - generative_t5 = GenerativeT5_pytorch(encoder_sess, decoder_sess, tokenizer, onnx=True) - token, logit = generative_t5(prompts[i], 16, temperature=0.) - #pdb.set_trace() - createInp(i) + global prompts + for i in range(0, 10): + # Use encoder and decoder to generate output text + decoder_sess = InferenceSession("t5-decoder-with-lm-head-12.onnx") + encoder_sess = InferenceSession("t5-encoder-12.onnx") + _, _, tokenizer = get_encoder_decoder_tokenizer() + generative_t5 = GenerativeT5_pytorch( + encoder_sess, decoder_sess, tokenizer, onnx=True + ) + token, logit = generative_t5(prompts[i], 16, temperature=0.0) + # pdb.set_trace() + createInp(i) + if __name__ == "__main__": main() diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/input.c b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/input.c index c8b86f90..141c1fb4 100644 --- a/sample_programs/ml_sample_programs/nlp_models/t5-decoder/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/t5-decoder/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/compile.sh b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/compile.sh index a55b0231..fc76dc16 100755 --- a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/compile.sh +++ b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/compile.sh @@ -17,7 +17,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp t5-encoder-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp t5-encoder-12.onnx mlir-translate -mlir-to-llvmir t5-encoder-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/createPredFile.py b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/createPredFile.py index 2c57b63f..9a19fcec 100644 --- a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/createPredFile.py +++ b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/createPredFile.py @@ -10,48 +10,77 @@ import json, pdb ROOT = os.getcwd() -LLFI_OUT = os.path.join(ROOT, 'llfi') -PROG_OUT = os.path.join(LLFI_OUT, 'prog_output') +LLFI_OUT = os.path.join(ROOT, "llfi") +PROG_OUT = os.path.join(LLFI_OUT, "prog_output") # WMT19 Dataset from huggingface prompts = [] -prompts.append('translate English to German: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000.') -prompts.append('translate English to German: Statements by the President') -prompts.append('translate English to French: Statements by the President') -prompts.append('translate English to French: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000.') -prompts.append('translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century.') -prompts.append('translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century.') -prompts.append('''translate English to German: (The House rose and observed a minute's silence)''') -prompts.append('''translate English to French: (The House rose and observed a minute's silence)''') -prompts.append('translate English to German: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims.') -prompts.append('translate English to French: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims.') - -def lltfi_sort(elem): - return int(elem.split('layeroutput')[-1].split('-')[-1].split('.txt')[0]) +prompts.append( + "translate English to German: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000." +) +prompts.append("translate English to German: Statements by the President") +prompts.append("translate English to French: Statements by the President") +prompts.append( + "translate English to French: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000." +) +prompts.append( + "translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century." +) +prompts.append( + "translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century." +) +prompts.append( + """translate English to German: (The House rose and observed a minute's silence)""" +) +prompts.append( + """translate English to French: (The House rose and observed a minute's silence)""" +) +prompts.append( + "translate English to German: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims." +) +prompts.append( + "translate English to French: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims." +) + + +def lltfi_sort(elem): + return int(elem.split("layeroutput")[-1].split("-")[-1].split(".txt")[0]) + class GenerativeT5_custom_encoder(torch.nn.Module): - """ Code Ref: https://github.com/abelriboulot/onnxt5 - Args: - encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - decoder_with_lm_head: decoder with language model head on top. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - tokenizer: huggingface tokenizer - onnx (bool): whether to use onnx or the default pytorch - cuda (bool): whether to use cuda or the cpu""" - def __init__(self, encoder_outputs_prompt, decoder_with_lm_head, tokenizer, cuda=False): + """Code Ref: https://github.com/abelriboulot/onnxt5 + Args: + encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + decoder_with_lm_head: decoder with language model head on top. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + tokenizer: huggingface tokenizer + onnx (bool): whether to use onnx or the default pytorch + cuda (bool): whether to use cuda or the cpu""" + + def __init__( + self, encoder_outputs_prompt, decoder_with_lm_head, tokenizer, cuda=False + ): super().__init__() self.encoder_outputs_prompt = encoder_outputs_prompt self.decoder_with_lm_head = decoder_with_lm_head self.tokenizer = tokenizer self.cuda = cuda - def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, top_p=0, max_context_length=512): - """ Forward function to generate text after a prompt - Args: - prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" - or "translate English to German:" - max_context_length: maximum number of tokens to use as context + def forward( + self, + max_length, + temperature=1.0, + repetition_penalty=1.0, + top_k=50, + top_p=0, + max_context_length=512, + ): + """Forward function to generate text after a prompt + Args: + prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" + or "translate English to German:" + max_context_length: maximum number of tokens to use as context """ with torch.no_grad(): new_tokens = torch.tensor(()) @@ -62,16 +91,24 @@ def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, t top_p = top_p # The sequence now needs to start with a - generated = torch.zeros((1,1), dtype=torch.long) + generated = torch.zeros((1, 1), dtype=torch.long) if self.cuda and not self.onnx: generated = generated.cuda() - for _ in range(max_length): - outputs = torch.tensor(self.decoder_with_lm_head.run(None, {"input_ids": generated.cpu().numpy(), - "encoder_hidden_states": self.encoder_outputs_prompt})) + outputs = torch.tensor( + self.decoder_with_lm_head.run( + None, + { + "input_ids": generated.cpu().numpy(), + "encoder_hidden_states": self.encoder_outputs_prompt, + }, + ) + ) outputs = outputs[0][0] - next_token_logits = outputs[-1, :] / (temperature if temperature > 0 else 1.0) + next_token_logits = outputs[-1, :] / ( + temperature if temperature > 0 else 1.0 + ) if int(next_token_logits.argmax()) == 1: break new_logits.append(next_token_logits) @@ -80,8 +117,12 @@ def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, t if temperature == 0: # greedy sampling: next_token = torch.argmax(next_token_logits).unsqueeze(0) else: - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) new_tokens = torch.cat((new_tokens, next_token), 0) @@ -91,8 +132,12 @@ def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, t def main(): # Get LLTFI outputs in listResArr listResArr = [] - list_of_files = sorted( filter( lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), - os.listdir(PROG_OUT) ), key=lltfi_sort ) + list_of_files = sorted( + filter( + lambda x: os.path.isfile(os.path.join(PROG_OUT, x)), os.listdir(PROG_OUT) + ), + key=lltfi_sort, + ) for i in range(len(list_of_files)): list_of_files[i] = os.path.join(PROG_OUT, list_of_files[i]) @@ -103,28 +148,31 @@ def main(): resultJson = json.load(read_file) for key, value in resultJson.items(): - resforSingleInput.append(value['Data']) + resforSingleInput.append(value["Data"]) listResArr.append(resforSingleInput) list_output_np = [] # Reshape the output and store as numpy array for elem in listResArr: output_np = np.asarray(elem[0], dtype=np.float32) - output_np = output_np.reshape(1,-1,768) + output_np = output_np.reshape(1, -1, 768) list_output_np.append(output_np) # Script to convert numpy output to text listPreds = [] - decoder_sess = InferenceSession('t5-decoder-with-lm-head-12.onnx') + decoder_sess = InferenceSession("t5-decoder-with-lm-head-12.onnx") _, _, tokenizer = get_encoder_decoder_tokenizer() for elemIndex in range(len(list_output_np)): - generative_t5 = GenerativeT5_custom_encoder(list_output_np[elemIndex], decoder_sess, tokenizer) - output_text = generative_t5(30, temperature=0.)[0] + generative_t5 = GenerativeT5_custom_encoder( + list_output_np[elemIndex], decoder_sess, tokenizer + ) + output_text = generative_t5(30, temperature=0.0)[0] listPreds.append(f"Run #{elemIndex} Prediction:{output_text}\n") - myfile = open('prediction/PredResult.txt', 'w') + myfile = open("prediction/PredResult.txt", "w") myfile.writelines(listPreds) myfile.close() + if __name__ == "__main__": main() diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/helper_scripts/generate_inputs.py b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/helper_scripts/generate_inputs.py index c333e473..f1979980 100644 --- a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/helper_scripts/generate_inputs.py +++ b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/helper_scripts/generate_inputs.py @@ -12,41 +12,68 @@ # WMT19 Dataset from huggingface prompts = [] -prompts.append('translate English to German: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000.') -prompts.append('translate English to German: Statements by the President') -prompts.append('translate English to French: Statements by the President') -prompts.append('translate English to French: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000.') -prompts.append('translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century.') -prompts.append('translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century.') -prompts.append('''translate English to German: (The House rose and observed a minute's silence)''') -prompts.append('''translate English to French: (The House rose and observed a minute's silence)''') -prompts.append('translate English to German: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims.') -prompts.append('translate English to French: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims.') +prompts.append( + "translate English to German: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000." +) +prompts.append("translate English to German: Statements by the President") +prompts.append("translate English to French: Statements by the President") +prompts.append( + "translate English to French: I declare resumed the session of the European Parliament adjourned on Friday, 15 December 2000." +) +prompts.append( + "translate English to German: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century." +) +prompts.append( + "translate English to French: Ladies and gentlemen, on Saturday, as you know, an earthquake struck Central America once again, with tragic consequences. This is an area which has already been seriously affected on a number of occasions since the beginning of the twentieth century." +) +prompts.append( + """translate English to German: (The House rose and observed a minute's silence)""" +) +prompts.append( + """translate English to French: (The House rose and observed a minute's silence)""" +) +prompts.append( + "translate English to German: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims." +) +prompts.append( + "translate English to French: I should like, on behalf of the European Parliament, to express our sympathy to the parents and families of the victims." +) class GenerativeT5_custom_encoder(torch.nn.Module): - """ Code Ref: https://github.com/abelriboulot/onnxt5 - Args: - encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - decoder_with_lm_head: decoder with language model head on top. Can be obtained with the - create_t5_encoder_decoder utility function for pytorch, see examples below. - tokenizer: huggingface tokenizer - onnx (bool): whether to use onnx or the default pytorch - cuda (bool): whether to use cuda or the cpu""" - def __init__(self, encoder_outputs_prompt, decoder_with_lm_head, tokenizer, cuda=False): + """Code Ref: https://github.com/abelriboulot/onnxt5 + Args: + encoder: huggingface encoder or onnx session for the encoder of T5. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + decoder_with_lm_head: decoder with language model head on top. Can be obtained with the + create_t5_encoder_decoder utility function for pytorch, see examples below. + tokenizer: huggingface tokenizer + onnx (bool): whether to use onnx or the default pytorch + cuda (bool): whether to use cuda or the cpu""" + + def __init__( + self, encoder_outputs_prompt, decoder_with_lm_head, tokenizer, cuda=False + ): super().__init__() self.encoder_outputs_prompt = encoder_outputs_prompt self.decoder_with_lm_head = decoder_with_lm_head self.tokenizer = tokenizer self.cuda = cuda - def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, top_p=0, max_context_length=512): - """ Forward function to generate text after a prompt - Args: - prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" - or "translate English to German:" - max_context_length: maximum number of tokens to use as context + def forward( + self, + max_length, + temperature=1.0, + repetition_penalty=1.0, + top_k=50, + top_p=0, + max_context_length=512, + ): + """Forward function to generate text after a prompt + Args: + prompt: str to run (don't forget to add at the beginning the task to run such as "summarize:" + or "translate English to German:" + max_context_length: maximum number of tokens to use as context """ with torch.no_grad(): new_tokens = torch.tensor(()) @@ -57,16 +84,24 @@ def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, t top_p = top_p # The sequence now needs to start with a - generated = torch.zeros((1,1), dtype=torch.long) + generated = torch.zeros((1, 1), dtype=torch.long) if self.cuda and not self.onnx: generated = generated.cuda() - for _ in range(max_length): - outputs = torch.tensor(self.decoder_with_lm_head.run(None, {"input_ids": generated.cpu().numpy(), - "encoder_hidden_states": self.encoder_outputs_prompt})) + outputs = torch.tensor( + self.decoder_with_lm_head.run( + None, + { + "input_ids": generated.cpu().numpy(), + "encoder_hidden_states": self.encoder_outputs_prompt, + }, + ) + ) outputs = outputs[0][0] - next_token_logits = outputs[-1, :] / (temperature if temperature > 0 else 1.0) + next_token_logits = outputs[-1, :] / ( + temperature if temperature > 0 else 1.0 + ) if int(next_token_logits.argmax()) == 1: break new_logits.append(next_token_logits) @@ -75,8 +110,12 @@ def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, t if temperature == 0: # greedy sampling: next_token = torch.argmax(next_token_logits).unsqueeze(0) else: - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + filtered_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p + ) + next_token = torch.multinomial( + F.softmax(filtered_logits, dim=-1), num_samples=1 + ) generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) new_tokens = torch.cat((new_tokens, next_token), 0) @@ -84,19 +123,23 @@ def forward(self, max_length, temperature=1., repetition_penalty=1., top_k=50, t def main(): - global prompts + global prompts - encoder_input = [] - max_context_length = 512 - _, _, tokenizer = get_encoder_decoder_tokenizer() - # Tokenize the prompts - for prompt in prompts: - encoder_input.append(torch.tensor(tokenizer(prompt)['input_ids'])[:max_context_length - 1].unsqueeze(0)) + encoder_input = [] + max_context_length = 512 + _, _, tokenizer = get_encoder_decoder_tokenizer() + # Tokenize the prompts + for prompt in prompts: + encoder_input.append( + torch.tensor(tokenizer(prompt)["input_ids"])[ + : max_context_length - 1 + ].unsqueeze(0) + ) - for index in range(len(encoder_input)): - tensor = numpy_helper.from_array(encoder_input[index].cpu().numpy()) - with open("input_{}.pb".format(index), 'wb') as file: - file.write(tensor.SerializeToString()) + for index in range(len(encoder_input)): + tensor = numpy_helper.from_array(encoder_input[index].cpu().numpy()) + with open("input_{}.pb".format(index), "wb") as file: + file.write(tensor.SerializeToString()) if __name__ == "__main__": diff --git a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/input.c b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/input.c index cfe31237..23aa7605 100644 --- a/sample_programs/ml_sample_programs/nlp_models/t5-encoder/input.c +++ b/sample_programs/ml_sample_programs/nlp_models/t5-encoder/input.c @@ -180,7 +180,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/vision_models/bvlcalexnet-12/compile.sh b/sample_programs/ml_sample_programs/vision_models/bvlcalexnet-12/compile.sh index acc6478e..d213aebd 100755 --- a/sample_programs/ml_sample_programs/vision_models/bvlcalexnet-12/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/bvlcalexnet-12/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp bvlcalexnet-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp bvlcalexnet-12.onnx mlir-translate -mlir-to-llvmir bvlcalexnet-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/cnn-fmnist.py b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/cnn-fmnist.py index df6c1521..2570048c 100644 --- a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/cnn-fmnist.py +++ b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/cnn-fmnist.py @@ -1,11 +1,15 @@ from tensorflow.keras import datasets, layers, models, losses + def process_images(images): - images = images.reshape((-1, 28, 28, 1)) - images = images / 255.0 - return images + images = images.reshape((-1, 28, 28, 1)) + images = images / 255.0 + return images + -(train_images, train_labels), (test_images, test_labels) = datasets.fashion_mnist.load_data() +(train_images, train_labels), (test_images, test_labels) = ( + datasets.fashion_mnist.load_data() +) train_images = process_images(train_images) test_images = process_images(test_images) @@ -19,14 +23,18 @@ def process_images(images): model.add(layers.Dense(10, activation="softmax")) model.compile( -optimizer="adam", -loss=losses.SparseCategoricalCrossentropy(), -metrics=["accuracy"], + optimizer="adam", + loss=losses.SparseCategoricalCrossentropy(), + metrics=["accuracy"], ) # Save the untrained weights for future training with modified dataset -model.fit(train_images, train_labels, batch_size=100, epochs=10, - validation_data=(test_images, test_labels)) - -model.save('./cnn-fmnist.tf') +model.fit( + train_images, + train_labels, + batch_size=100, + epochs=10, + validation_data=(test_images, test_labels), +) +model.save("./cnn-fmnist.tf") diff --git a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/compile.sh b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/compile.sh index 8665f760..16dabd38 100755 --- a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/compile.sh @@ -1,6 +1,6 @@ printf "\n[Compile Script]: Convert ONNX model to LLVM IR\n" python3 ../../../../tools/ExtendONNXModel.py --model_path ./model.onnx --output_model_path ./extendedmodel.onnx > expected_op_seq.txt -onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp +onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp mlir-translate -mlir-to-llvmir extendedmodel.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/image.c b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/image.c index aac7662f..9f1c5292 100644 --- a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/image.c +++ b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/image.c @@ -157,7 +157,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/stb_image.h b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/stb_image.h index accef483..5b891039 100644 --- a/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/stb_image.h +++ b/sample_programs/ml_sample_programs/vision_models/cnn-fmnist/stb_image.h @@ -3,7 +3,8 @@ Do this: #define STB_IMAGE_IMPLEMENTATION - before you include this file in *one* C or C++ file to create the implementation. + before you include this file in *one* C or C++ file to create the +implementation. // i.e. it should look like this: #include ... @@ -13,15 +14,16 @@ #include "stb_image.h" You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. - And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using +malloc,realloc,free QUICK NOTES: Primarily of interest to game developers and other people who can avoid problematic images and only need the trivial interface - JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) - PNG 1/2/4/8/16-bit-per-channel + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as +stock IJG lib) PNG 1/2/4/8/16-bit-per-channel TGA (not sure what subset, if a subset) BMP non-1bpp, non-RLE @@ -50,25 +52,22 @@ RECENT REVISION HISTORY: 2.26 (2020-07-13) many minor fixes 2.25 (2020-02-02) fix warnings - 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically - 2.23 (2019-08-11) fix clang static analysis warning - 2.22 (2019-03-04) gif fixes, fix warnings - 2.21 (2019-02-25) fix typo in comment - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and +flip_vertically 2.23 (2019-08-11) fix clang static analysis warning 2.22 +(2019-03-04) gif fixes, fix warnings 2.21 (2019-02-25) fix typo in comment 2.20 +(2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix warnings 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings - 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes - 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 - RGB-format JPEG; remove white matting in PSD; - allocate large structures on the stack; - correct channel count for PNG & BMP - 2.10 (2016-01-22) avoid warning introduced in 2.09 - 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; +bugfixes 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE +detection on GCC 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for +Imagenet JPGs 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; +fixes 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 +(2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 RGB-format JPEG; remove +white matting in PSD; allocate large structures on the stack; correct channel +count for PNG & BMP 2.10 (2016-01-22) avoid warning introduced in 2.09 2.09 +(2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED See end of file for full revision history. @@ -86,38 +85,37 @@ RECENT REVISION HISTORY: github:urraka (animated gif) Junggon Kim (PNM comments) Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) socks-the-fox (16-bit PNG) - Jeremy Sawicki (handle all ImageNet JPGs) - Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Jeremy Sawicki (handle all ImageNet +JPGs) Optimizations & bugfixes Mikhail Morozov (1-bit BMP) Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) Arseny Kapoulkine John-Mark Allen Carmelo J Fdez-Aguera Bug & warning fixes - Marc LeBlanc David Woo Guillaume George Martins Mozeiko - Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski - Phil Jordan Dave Moore Roy Eltham - Hayaki Saito Nathan Reed Won Chun - Luke Graham Johan Duparc Nick Verigakis the Horde3D community - Thomas Ruf Ronny Chevalier github:rlyeh - Janez Zemva John Bartholomew Michal Cichon github:romigrou - Jonathan Blow Ken Hamada Tero Hanninen github:svdijk - Laurent Gomila Cort Stratton github:snagar - Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex - Cass Everitt Ryamond Barbiero github:grim210 - Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw - Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus - Josh Tobin Matthew Gregan github:poppolopoppo - Julian Raschke Gregory Mullen Christian Floisand github:darealshinji - Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 - Brad Weinberger Matvey Cherevko [reserved] - Luca Sas Alexander Veselov Zack Middleton [reserved] + Marc LeBlanc David Woo Guillaume George Martins +Mozeiko Christpher Lloyd Jerry Jansson Joseph Thomson Blazej +Dariusz Roszkowski Phil Jordan Dave Moore Roy +Eltham Hayaki Saito Nathan Reed Won Chun Luke Graham Johan +Duparc Nick Verigakis the Horde3D community Thomas Ruf Ronny +Chevalier github:rlyeh Janez Zemva John +Bartholomew Michal Cichon github:romigrou Jonathan Blow Ken +Hamada Tero Hanninen github:svdijk Laurent Gomila Cort +Stratton github:snagar Aruelien Pocheville Sergio Gonzalez Thibault +Reuille github:Zelex Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Matthew Gregan +github:poppolopoppo Julian Raschke Gregory Mullen Christian +Floisand github:darealshinji Baldur Karlsson Kevin Schmidt JR +Smith github:Michaelangel007 Brad Weinberger Matvey Cherevko +[reserved] Luca Sas Alexander Veselov Zack Middleton [reserved] Ryan C. Gordon [reserved] [reserved] DO NOT ADD YOUR NAME HERE - To add your name to the credits, pick a random blank space in the middle and fill it. - 80% of merge conflicts on stb PRs are due to people adding their name at the end - of the credits. + To add your name to the credits, pick a random blank space in the middle and +fill it. 80% of merge conflicts on stb PRs are due to people adding their name +at the end of the credits. */ #ifndef STBI_INCLUDE_STB_IMAGE_H @@ -136,14 +134,15 @@ RECENT REVISION HISTORY: // // ... process data if not NULL ... // // ... x = width, y = height, n = # 8-bit components per pixel ... // // ... replace '0' with '1'..'4' to force that many components per pixel -// // ... but 'n' will always be the number that it would have been if you said 0 -// stbi_image_free(data) +// // ... but 'n' will always be the number that it would have been if you +// said 0 stbi_image_free(data) // // Standard parameters: // int *x -- outputs image width in pixels // int *y -- outputs image height in pixels // int *channels_in_file -- outputs # of image components in image file -// int desired_channels -- if non-zero, # of image components requested in result +// int desired_channels -- if non-zero, # of image components requested in +// result // // The return value from an image loader is an 'unsigned char *' which points // to the pixel data, or NULL on an allocation failure or if the image is @@ -171,8 +170,8 @@ RECENT REVISION HISTORY: // and *x, *y, *channels_in_file will be unchanged. The function // stbi_failure_reason() can be queried for an extremely brief, end-user // unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS -// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly -// more user-friendly ones. +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get +// slightly more user-friendly ones. // // Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. // @@ -196,11 +195,12 @@ RECENT REVISION HISTORY: // 2. easy to maintain // 3. good performance // -// Sometimes I let "good performance" creep up in priority over "easy to maintain", -// and for best performance I may provide less-easy-to-use APIs that give higher -// performance, in addition to the easy-to-use ones. Nevertheless, it's important -// to keep in mind that from the standpoint of you, a client of this library, -// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// Sometimes I let "good performance" creep up in priority over "easy to +// maintain", and for best performance I may provide less-easy-to-use APIs that +// give higher performance, in addition to the easy-to-use ones. Nevertheless, +// it's important to keep in mind that from the standpoint of you, a client of +// this library, all you care about is #1 and #3, and stb libraries DO NOT +// emphasize #3 above all. // // Some secondary priorities arise directly from the first two, some of which // provide more explicit reasons why performance can't be emphasized. @@ -219,7 +219,8 @@ RECENT REVISION HISTORY: // overhead. // // The three functions you must define are "read" (reads some bytes of data), -// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the +// end). // // =========================================================================== // @@ -247,10 +248,11 @@ RECENT REVISION HISTORY: // HDR image support (disable by defining STBI_NO_HDR) // // stb_image supports loading HDR images in general, and currently the Radiance -// .HDR file format specifically. You can still load any file through the existing -// interface; if you attempt to load an HDR file, it will be automatically remapped -// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; -// both of these constants can be reconfigured through this interface: +// .HDR file format specifically. You can still load any file through the +// existing interface; if you attempt to load an HDR file, it will be +// automatically remapped to LDR, assuming gamma 2.2 and an arbitrary scale +// factor defaulting to 1; both of these constants can be reconfigured through +// this interface: // // stbi_hdr_to_ldr_gamma(2.2f); // stbi_hdr_to_ldr_scale(1.0f); @@ -342,14 +344,13 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum -{ - STBI_default = 0, // only used for desired_channels +enum { + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -377,11 +378,13 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct -{ - int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof) (void *user); // returns nonzero if we are at end of file/data +typedef struct { + int (*read)(void *user, char *data, + int size); // fill 'data' with 'size' bytes. return number of + // bytes actually read + void (*skip)(void *user, int n); // skip the next 'n' bytes, or 'unget' the + // last -n bytes if negative + int (*eof)(void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -389,21 +392,33 @@ typedef struct // 8-bits-per-channel interface // -STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); -// for stbi_load_from_file, file pointer is left pointing immediately after image +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after +// image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input); #endif //////////////////////////////////// @@ -411,12 +426,20 @@ STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wch // 16-bits-per-channel interface // -STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); #endif //////////////////////////////////// @@ -424,83 +447,102 @@ STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_i // float-per-channel interface // #ifndef STBI_NO_LINEAR - STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); - #ifndef STBI_NO_STDIO - STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); - #endif +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +#endif #endif #ifndef STBI_NO_HDR - STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); - STBIDEF void stbi_hdr_to_ldr_scale(float scale); +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); +STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR - STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); - STBIDEF void stbi_ldr_to_hdr_scale(float scale); +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); +STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename); -STBIDEF int stbi_is_hdr_from_file(FILE *f); +STBIDEF int stbi_is_hdr(char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO - // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char *stbi_failure_reason (void); +STBIDEF const char *stbi_failure_reason(void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free (void *retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, + int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, + void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit (char const *filename); -STBIDEF int stbi_is_16_bit_from_file(FILE *f); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit(char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif - - // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); // indicate whether we should process iphone images back to canonical format, // or just pass them through "as-is" STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); -// flip the image vertically, so the first pixel in the output array is the bottom left +// flip the image vertically, so the first pixel in the output array is the +// bottom left STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); -// as above, but only applies to images loaded on the thread that calls the function -// this function is only available if your compiler supports thread-local variables; -// calling it will fail to link if your compiler doesn't -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); +// as above, but only applies to images loaded on the thread that calls the +// function this function is only available if your compiler supports +// thread-local variables; calling it will fail to link if your compiler doesn't +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); // ZLIB client - used by PNG, available for other purposes -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header); STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); - -STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, + int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); #ifdef __cplusplus } @@ -513,52 +555,53 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ - || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ - || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ - || defined(STBI_ONLY_ZLIB) - #ifndef STBI_ONLY_JPEG - #define STBI_NO_JPEG - #endif - #ifndef STBI_ONLY_PNG - #define STBI_NO_PNG - #endif - #ifndef STBI_ONLY_BMP - #define STBI_NO_BMP - #endif - #ifndef STBI_ONLY_PSD - #define STBI_NO_PSD - #endif - #ifndef STBI_ONLY_TGA - #define STBI_NO_TGA - #endif - #ifndef STBI_ONLY_GIF - #define STBI_NO_GIF - #endif - #ifndef STBI_ONLY_HDR - #define STBI_NO_HDR - #endif - #ifndef STBI_ONLY_PIC - #define STBI_NO_PIC - #endif - #ifndef STBI_ONLY_PNM - #define STBI_NO_PNM - #endif -#endif - -#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) -#define STBI_NO_ZLIB +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || \ + defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || \ + defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ + defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) +#ifndef STBI_ONLY_JPEG +#define STBI_NO_JPEG +#endif +#ifndef STBI_ONLY_PNG +#define STBI_NO_PNG +#endif +#ifndef STBI_ONLY_BMP +#define STBI_NO_BMP +#endif +#ifndef STBI_ONLY_PSD +#define STBI_NO_PSD +#endif +#ifndef STBI_ONLY_TGA +#define STBI_NO_TGA +#endif +#ifndef STBI_ONLY_GIF +#define STBI_NO_GIF +#endif +#ifndef STBI_ONLY_HDR +#define STBI_NO_HDR +#endif +#ifndef STBI_ONLY_PIC +#define STBI_NO_PIC +#endif +#ifndef STBI_ONLY_PNM +#define STBI_NO_PNM +#endif #endif +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && \ + !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif +#include #include #include // ptrdiff_t on osx #include #include -#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -576,55 +619,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #define STBI_EXTERN extern #endif - #ifndef _MSC_VER - #ifdef __cplusplus - #define stbi_inline inline - #else - #define stbi_inline - #endif +#ifdef __cplusplus +#define stbi_inline inline +#else +#define stbi_inline +#endif #else - #define stbi_inline __forceinline +#define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS - #if defined(__cplusplus) && __cplusplus >= 201103L - #define STBI_THREAD_LOCAL thread_local - #elif defined(__GNUC__) && __GNUC__ < 5 - #define STBI_THREAD_LOCAL __thread - #elif defined(_MSC_VER) - #define STBI_THREAD_LOCAL __declspec(thread) - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) - #define STBI_THREAD_LOCAL _Thread_local - #endif - - #ifndef STBI_THREAD_LOCAL - #if defined(__GNUC__) - #define STBI_THREAD_LOCAL __thread - #endif - #endif +#if defined(__cplusplus) && __cplusplus >= 201103L +#define STBI_THREAD_LOCAL thread_local +#elif defined(__GNUC__) && __GNUC__ < 5 +#define STBI_THREAD_LOCAL __thread +#elif defined(_MSC_VER) +#define STBI_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && \ + !defined(__STDC_NO_THREADS__) +#define STBI_THREAD_LOCAL _Thread_local +#endif + +#ifndef STBI_THREAD_LOCAL +#if defined(__GNUC__) +#define STBI_THREAD_LOCAL __thread +#endif +#endif #endif #ifdef _MSC_VER typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -632,27 +675,30 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL - #define stbi_lrot(x,y) _lrotl(x,y) +#define stbi_lrot(x, y) _lrotl(x, y) #else - #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (32 - (y)))) +#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (32 - (y)))) #endif -#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +#if defined(STBI_MALLOC) && defined(STBI_FREE) && \ + (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) // ok -#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && \ + !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) // ok #else -#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." +#error \ + "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p,newsz) realloc(p,newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p, newsz) realloc(p, newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) +#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) #endif // x86/x64 detection @@ -662,7 +708,8 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI__X86_TARGET #endif -#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && \ + !defined(STBI_NO_SIMD) // gcc doesn't support sse2 intrinsics unless you compile with -msse2, // which in turn means it gets to use SSE2 everywhere. This is unfortunate, // but previous attempts to provide the SSE2 functions with runtime @@ -673,8 +720,10 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI_NO_SIMD #endif -#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) -// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && \ + !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid +// STBI__X64_TARGET // // 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the // Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. @@ -684,44 +733,43 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; // See https://github.com/nothings/stb/issues/81 for more information. // // So default to no SSE2 on 32-bit MinGW. If you've read this far and added -// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +// -mstackrealign to your build settings, feel free to #define +// STBI_MINGW_ENABLE_SSE2. #define STBI_NO_SIMD #endif -#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#if !defined(STBI_NO_SIMD) && \ + (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) #define STBI_SSE2 #include #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) -{ - int info[4]; - __cpuid(info,1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) { + int info[4]; + __cpuid(info, 1); + return info[3]; } #else -static int stbi__cpuid3(void) -{ - int res; - __asm { +static int stbi__cpuid3(void) { + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) { + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -729,12 +777,11 @@ static int stbi__sse2_available(void) #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) { + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -766,188 +813,182 @@ static int stbi__sse2_available(void) // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct -{ - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct { + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void *io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; - static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) -{ - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) { + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) -{ - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, + void *user) { + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void *user, char *data, int size) -{ - return (int) fread(data,1,size,(FILE*) user); +static int stbi__stdio_read(void *user, char *data, int size) { + return (int)fread(data, 1, size, (FILE *)user); } -static void stbi__stdio_skip(void *user, int n) -{ - int ch; - fseek((FILE*) user, n, SEEK_CUR); - ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ - } +static void stbi__stdio_skip(void *user, int n) { + int ch; + fseek((FILE *)user, n, SEEK_CUR); + ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ + } } -static int stbi__stdio_eof(void *user) -{ - return feof((FILE*) user) || ferror((FILE *) user); +static int stbi__stdio_eof(void *user) { + return feof((FILE *)user) || ferror((FILE *)user); } -static stbi_io_callbacks stbi__stdio_callbacks = -{ - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static stbi_io_callbacks stbi__stdio_callbacks = { + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context *s, FILE *f) -{ - stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +static void stbi__start_file(stbi__context *s, FILE *f) { + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } -//static void stop_file(stbi__context *s) { } +// static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context *s) -{ - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 + // bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum -{ - STBI_ORDER_RGB, - STBI_ORDER_BGR -}; +enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; -typedef struct -{ - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct { + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context *s); -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context *s); -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__png_is16(stbi__context *s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context *s); -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context *s); -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s); -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__psd_is16(stbi__context *s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context *s); -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context *s); -static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context *s); -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s); -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); #endif static #ifdef STBI_THREAD_LOCAL -STBI_THREAD_LOCAL + STBI_THREAD_LOCAL #endif -const char *stbi__g_failure_reason; + const char *stbi__g_failure_reason; -STBIDEF const char *stbi_failure_reason(void) -{ - return stbi__g_failure_reason; +STBIDEF const char *stbi_failure_reason(void) { + return stbi__g_failure_reason; } #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char *str) -{ - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) { + stbi__g_failure_reason = str; + return 0; } #endif -static void *stbi__malloc(size_t size) -{ - return STBI_MALLOC(size); +static void *stbi__malloc(size_t size) { + return STBI_MALLOC(size); } // stb_image uses ints pervasively, including for offset calculations. @@ -962,70 +1003,72 @@ static void *stbi__malloc(size_t size) // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) -{ - if (b < 0) return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) { + if (b < 0) + return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) -{ - if (a < 0 || b < 0) return 0; - if (b == 0) return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX/b; +static int stbi__mul2sizes_valid(int a, int b) { + if (a < 0 || b < 0) + return 0; + if (b == 0) + return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; } -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__addsizes_valid(a*b*c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__addsizes_valid(a * b * c, add); } -// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't +// overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__mul2sizes_valid(a * b * c, d) && + stbi__addsizes_valid(a * b * c * d, add); } #endif -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void *stbi__malloc_mad2(int a, int b, int add) -{ - if (!stbi__mad2sizes_valid(a, b, add)) return NULL; - return stbi__malloc(a*b + add); +static void *stbi__malloc_mad2(int a, int b, int add) { + if (!stbi__mad2sizes_valid(a, b, add)) + return NULL; + return stbi__malloc(a * b + add); } #endif -static void *stbi__malloc_mad3(int a, int b, int c, int add) -{ - if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; - return stbi__malloc(a*b*c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) { + if (!stbi__mad3sizes_valid(a, b, c, add)) + return NULL; + return stbi__malloc(a * b * c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) -{ - if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; - return stbi__malloc(a*b*c*d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) { + if (!stbi__mad4sizes_valid(a, b, c, d, add)) + return NULL; + return stbi__malloc(a * b * c * d + add); } #endif @@ -1034,417 +1077,459 @@ static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS - #define stbi__err(x,y) 0 +#define stbi__err(x, y) 0 #elif defined(STBI_FAILURE_USERMSG) - #define stbi__err(x,y) stbi__err(y) +#define stbi__err(x, y) stbi__err(y) #else - #define stbi__err(x,y) stbi__err(x) +#define stbi__err(x, y) stbi__err(x) #endif -#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) -#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpuc(x, y) \ + ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -STBIDEF void stbi_image_free(void *retval_from_stbi_load) -{ - STBI_FREE(retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load) { + STBI_FREE(retval_from_stbi_load); } #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else -static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, + stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ - ? stbi__vertically_flip_on_load_local \ - : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load \ + (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; - - #ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNG - if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_GIF - if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PSD - if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); - #else - STBI_NOTUSED(bpc); - #endif - #ifndef STBI_NO_PIC - if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); - #endif - - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } - #endif - - #ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s,x,y,comp,req_comp, ri); - #endif - - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); -} - -static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi_uc *reduced; - - reduced = (stbi_uc *) stbi__malloc(img_len); - if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - - STBI_FREE(orig); - return reduced; -} - -static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi__uint16 *enlarged; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = + 8; // default is 8 so most paths don't have to be changed + ri->channel_order = + STBI_ORDER_RGB; // all current input & output are this, but this is here + // so we can add BGR order + ri->num_channels = 0; - enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); - if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); +#ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) + return stbi__jpeg_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNG + if (stbi__png_test(s)) + return stbi__png_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) + return stbi__bmp_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_GIF + if (stbi__gif_test(s)) + return stbi__gif_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PSD + if (stbi__psd_test(s)) + return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); +#else + STBI_NOTUSED(bpc); +#endif +#ifndef STBI_NO_PIC + if (stbi__pic_test(s)) + return stbi__pic_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) + return stbi__pnm_load(s, x, y, comp, req_comp, ri); +#endif - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } +#endif - STBI_FREE(orig); - return enlarged; -} +#ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s, x, y, comp, req_comp, ri); +#endif -static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) -{ - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc *bytes = (stbi_uc *)image; - - for (row = 0; row < (h>>1); row++) { - stbi_uc *row0 = bytes + row*bytes_per_row; - stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } + return stbi__errpuc("unknown image type", + "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *)stbi__malloc(img_len); + if (reduced == NULL) + return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = + (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient + // approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); + if (enlarged == NULL) + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + + orig[i]); // replicate to high and low byte, + // maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, + int bytes_per_pixel) { + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h >> 1); row++) { + stbi_uc *row0 = bytes + row * bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1) * bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = + (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) -{ - int slice; - int slice_size = w * h * bytes_per_pixel; +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, + int bytes_per_pixel) { + int slice; + int slice_size = w * h * bytes_per_pixel; - stbi_uc *bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *) result; + return (unsigned char *)result; } -static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to + // keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *) result; + return (stbi__uint16 *)result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) -{ - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, + int req_comp) { + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar( + unsigned int cp, unsigned long flags, const char *str, int cbmb, + wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte( + unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, + char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) -{ - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input) { + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, + (int)bufferlen, NULL, NULL); } #endif -static FILE *stbi__fopen(char const *filename, char const *mode) -{ - FILE *f; +static FILE *stbi__fopen(char const *filename, char const *mode) { + FILE *f; #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, + sizeof(wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) - return 0; + if (0 == + MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) + return 0; #if _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f=0; + if (0 != fopen_s(&f, filename, mode)) + f = 0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; -} - - -STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - unsigned char *result; - if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; -} - -STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__uint16 *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - stbi__uint16 *result; - if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f,x,y,comp,req_comp); - fclose(f); - return result; -} - - -#endif //!STBI_NO_STDIO - -STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); -} - -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + return f; +} + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) + return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) + return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +#endif //! STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_mem(&s,buffer,len); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_mem(&s, buffer, len); - result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); - } + result = + (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices(result, *x, *y, *z, *comp); + } - return result; + return result; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *data; - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data,x,y,comp,req_comp); - return hdr_data; - } - #endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); -} - -STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__loadf_main(&s,x,y,comp,req_comp); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp) { + unsigned char *data; +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data, x, y, comp, req_comp); + return hdr_data; + } +#endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", + "Image not of any known type, or corrupt"); } -STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__loadf_main(&s,x,y,comp,req_comp); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -#ifndef STBI_NO_STDIO -STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - float *result; - FILE *f = stbi__fopen(filename, "rb"); - if (!f) return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_file(&s,f); - return stbi__loadf_main(&s,x,y,comp,req_comp); +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, + int req_comp) { + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) + return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_file(&s, f); + return stbi__loadf_main(&s, x, y, comp, req_comp); } #endif // !STBI_NO_STDIO @@ -1454,221 +1539,222 @@ STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_ // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; +#endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result=0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result = 0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE *f) -{ - #ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s,f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; - #else - STBI_NOTUSED(f); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) { +#ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s, f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; +#else + STBI_NOTUSED(f); + return 0; +#endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; +#endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; +static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { + stbi__l2h_gamma = gamma; +} +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { + stbi__l2h_scale = scale; +} #endif -static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; - -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } +static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { + stbi__h2l_gamma_i = 1 / gamma; +} +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { + stbi__h2l_scale_i = 1 / scale; +} ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum -{ - STBI__SCAN_load=0, - STBI__SCAN_type, - STBI__SCAN_header -}; - -static void stbi__refill_buffer(stbi__context *s) -{ - int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); - s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start+1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } -} - -stbi_inline static stbi_uc stbi__get8(stbi__context *s) -{ - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; -} - -#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; + +static void stbi__refill_buffer(stbi__context *s) { + int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context *s) -{ - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) + return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) + return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context *s, int n) -{ - if (n == 0) return; // already there! - if (n < 0) { +static void stbi__skip(stbi__context *s, int n) { + if (n == 0) + return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); return; - } - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; + } + } + s->img_buffer += n; } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && \ + defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) -{ - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) { + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; - memcpy(buffer, s->img_buffer, blen); + memcpy(buffer, s->img_buffer, blen); - count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); - res = (count == (n-blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } + count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); + res = (count == (n - blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } - if (s->img_buffer+n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + if (s->img_buffer + n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context *s) -{ - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) { + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context *s) -{ - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) { + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context *s) -{ - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) { + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context *s) -{ - stbi__uint32 z = stbi__get16le(s); - return z + (stbi__get16le(s) << 16); +static stbi__uint32 stbi__get32le(stbi__context *s) { + stbi__uint32 z = stbi__get16le(s); + return z + (stbi__get16le(s) << 16); } #endif -#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) \ + ((stbi_uc)((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1682,169 +1768,301 @@ static stbi__uint32 stbi__get32le(stbi__context *s) // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) -{ - return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi_uc stbi__compute_y(int r, int g, int b) { + return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - unsigned char *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - unsigned char *src = data + j * x * img_n ; - unsigned char *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + unsigned char *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + unsigned char *src = data + j * x * img_n; + unsigned char *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 255; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 255; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 255; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = 255; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) -{ - return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { + return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - stbi__uint16 *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - stbi__uint16 *src = data + j * x * img_n ; - stbi__uint16 *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + stbi__uint16 *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + stbi__uint16 *src = data + j * x * img_n; + stbi__uint16 *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 0xffff; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 0xffff; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 0xffff; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = 0xffff; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return (stbi__uint16 *)stbi__errpuc("unsupported", + "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) -{ - int i,k,n; - float *output; - if (!data) return NULL; - output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i=0; i < x*y; ++i) { - output[i*comp + n] = data[i*comp + n]/255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) { + int i, k, n; + float *output; + if (!data) + return NULL; + output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpf("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + output[i * comp + k] = + (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * + stbi__l2h_scale); + } + } + if (n < comp) { + for (i = 0; i < x * y; ++i) { + output[i * comp + n] = data[i * comp + n] / 255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int) (x)) -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) -{ - int i,k,n; - stbi_uc *output; - if (!data) return NULL; - output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - if (k < comp) { - float z = data[i*comp+k] * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int)(x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) { + int i, k, n; + stbi_uc *output; + if (!data) + return NULL; + output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, + stbi__h2l_gamma_i) * + 255 + + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + if (k < comp) { + float z = data[i * comp + k] * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1872,750 +2090,791 @@ static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache - -typedef struct -{ - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct { + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct -{ - stbi__context *s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; - -// sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; - -// definition of jpeg image component - struct - { - int id; - int h,v; - int tq; - int hd,ha; - int dc_pred; - - int x,y,w2,h2; - stbi_uc *data; - void *raw_data, *raw_coeff; - stbi_uc *linebuf; - short *coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; - - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop - - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; - - int scan_n, order[4]; - int restart_interval, todo; - -// kernels - void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); - stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +typedef struct { + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, int count, + int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman *h, int *count) -{ - int i,j,k=0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i=0; i < 16; ++i) - for (j=0; j < count[i]; ++j) - h->size[k++] = (stbi_uc) (i+1); - h->size[k] = 0; - - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for(j=1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16) (code++); - if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16-j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; - - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i=0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS-s); - int m = 1 << (FAST_BITS-s); - for (j=0; j < m; ++j) { - h->fast[c+j] = (stbi_uc) i; - } +static int stbi__build_huffman(stbi__huffman *h, int *count) { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) + for (j = 0; j < count[i]; ++j) + h->size[k++] = (stbi_uc)(i + 1); + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16)(code++); + if (code - 1 >= (1u << j)) + return stbi__err("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (stbi_uc)i; } - } - return 1; + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) -{ - int i; - for (i=0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; - - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); - } +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) + k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); } - } -} - -static void stbi__grow_buffer_unsafe(stbi__jpeg *j) -{ - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char) c; - j->nomore = 1; - return; - } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) { + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) + c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; +static const stbi__uint32 stbi__bmask[17] = { + 0, 1, 3, 7, 15, 31, 63, 127, 255, + 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) -{ - unsigned int temp; - int c,k; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } - - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k=FAST_BITS+1 ; ; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } - - if (k > j->code_bits) +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) return -1; - - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1;; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & + stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) stbi__grow_buffer_unsafe(j); - - sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB - k = stbi_lrot(j->code_buffer, n); - if (n < 0 || n >= (int) (sizeof(stbi__bmask)/sizeof(*stbi__bmask))) return 0; - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & ~sgn); +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) { + unsigned int k; + int sgn; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + + sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB + k = stbi_lrot(j->code_buffer, n); + if (n < 0 || n >= (int)(sizeof(stbi__bmask) / sizeof(*stbi__bmask))) + return 0; + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & ~sgn); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) -{ - unsigned int k; - if (j->code_bits < n) stbi__grow_buffer_unsafe(j); - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; -} - -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) -{ - unsigned int k; - if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) { + unsigned int k; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) { + unsigned int k; + if (j->code_bits < 1) + stbi__grow_buffer_unsafe(j); + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64+15] = -{ - 0, 1, 8, 16, 9, 2, 3, 10, - 17, 24, 32, 25, 18, 11, 4, 5, - 12, 19, 26, 33, 40, 48, 41, 34, - 27, 20, 13, 6, 7, 14, 21, 28, - 35, 42, 49, 56, 57, 50, 43, 36, - 29, 22, 15, 23, 30, 37, 44, 51, - 58, 59, 52, 45, 38, 31, 39, 46, - 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, - 63, 63, 63, 63, 63, 63, 63 -}; +static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { + 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, + 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, + 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) -{ - int diff,dc,k; - int t; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - - // 0 all the ac values now so we can do it 32-bits at a time - memset(data,0,64*sizeof(data[0])); - - diff = t ? stbi__extend_receive(j, t) : 0; - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc * dequant[0]); - - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) * dequant[zig]); +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, stbi__huffman *hac, + stbi__int16 *fac, int b, + stbi__uint16 *dequant) { + int diff, dc, k; + int t; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data, 0, 64 * sizeof(data[0])); + + diff = t ? stbi__extend_receive(j, t) : 0; + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) + break; // end block + k += 16; } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (rs != 0xf0) break; // end block - k += 16; - } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); - } + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); } - } while (k < 64); - return 1; -} - -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) -{ - int diff,dc; - int t; - if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data,0,64*sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t == -1) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; - - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc << j->succ_low); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short) (1 << j->succ_low); - } - return 1; + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, int b) { + int diff, dc; + int t; + if (j->spec_end != 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t == -1) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; + + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc << j->succ_low); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; } // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) -{ - int k; - if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->succ_high == 0) { - int shift = j->succ_low; +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], + stbi__huffman *hac, + stbi__int16 *fac) { + int k; + if (j->spec_start == 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } - if (j->eob_run) { - --j->eob_run; - return 1; + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) << shift); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) << shift); + } } - + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { k = j->spec_start; do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) << shift); - } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) << shift); - } - } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients - - short bit = (short) (1 << j->succ_low); - - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short *p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } - } else { - k = j->spec_start; - do { - int r,s; - int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } + int r, s; + int rs = stbi__jpeg_huff_decode( + j, hac); // @OPTIMIZE see if we can use the fast path here, + // advance-by-r is so slow, eh + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) + return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } - // advance by r - while (k <= j->spec_end) { - short *p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short) s; - break; - } - --r; - } + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short)s; + break; } - } while (k <= j->spec_end); - } - } - return 1; + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) -{ - // trick to use a single test to catch both cases - if ((unsigned int) x > 255) { - if (x < 0) return 0; - if (x > 255) return 255; - } - return (stbi_uc) x; +stbi_inline static stbi_uc stbi__clamp(int x) { + // trick to use a single test to catch both cases + if ((unsigned int)x > 255) { + if (x < 0) + return 0; + if (x > 255) + return 255; + } + return (stbi_uc)x; } -#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) -#define stbi__fsh(x) ((x) * 4096) +#define stbi__f2f(x) ((int)(((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ - int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2+p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3*stbi__f2f(-1.847759065f); \ - t3 = p1 + p2*stbi__f2f( 0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2+p3); \ - t1 = stbi__fsh(p2-p3); \ - x0 = t0+t3; \ - x3 = t0-t3; \ - x1 = t1+t2; \ - x2 = t1-t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0+t2; \ - p4 = t1+t3; \ - p1 = t0+t3; \ - p2 = t1+t2; \ - p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ - t0 = t0*stbi__f2f( 0.298631336f); \ - t1 = t1*stbi__f2f( 2.053119869f); \ - t2 = t2*stbi__f2f( 3.072711026f); \ - t3 = t3*stbi__f2f( 1.501321110f); \ - p1 = p5 + p1*stbi__f2f(-0.899976223f); \ - p2 = p5 + p2*stbi__f2f(-2.562915447f); \ - p3 = p3*stbi__f2f(-1.961570560f); \ - p4 = p4*stbi__f2f(-0.390180644f); \ - t3 += p1+p4; \ - t2 += p2+p3; \ - t1 += p2+p4; \ - t0 += p1+p3; - -static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) -{ - int i,val[64],*v=val; - stbi_uc *o; - short *d = data; - - // columns - for (i=0; i < 8; ++i,++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 - && d[40]==0 && d[48]==0 && d[56]==0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0]*4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; x1 += 512; x2 += 512; x3 += 512; - v[ 0] = (x0+t3) >> 10; - v[56] = (x0-t3) >> 10; - v[ 8] = (x1+t2) >> 10; - v[48] = (x1-t2) >> 10; - v[16] = (x2+t1) >> 10; - v[40] = (x2-t1) >> 10; - v[24] = (x3+t0) >> 10; - v[32] = (x3-t0) >> 10; - } - } - - for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128<<17); - x1 += 65536 + (128<<17); - x2 += 65536 + (128<<17); - x3 += 65536 + (128<<17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0+t3) >> 17); - o[7] = stbi__clamp((x0-t3) >> 17); - o[1] = stbi__clamp((x1+t2) >> 17); - o[6] = stbi__clamp((x1-t2) >> 17); - o[2] = stbi__clamp((x2+t1) >> 17); - o[5] = stbi__clamp((x2-t1) >> 17); - o[3] = stbi__clamp((x3+t0) >> 17); - o[4] = stbi__clamp((x3-t0) >> 17); - } +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ + t0 += p1 + p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) { + int i, val[64], *v = val; + stbi_uc *o; + short *d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && + d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; + x1 += 512; + x2 += 512; + x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0 + t3) >> 17); + o[7] = stbi__clamp((x0 - t3) >> 17); + o[1] = stbi__clamp((x1 + t2) >> 17); + o[6] = stbi__clamp((x1 - t2) >> 17); + o[2] = stbi__clamp((x2 + t1) >> 17); + o[5] = stbi__clamp((x2 - t1) >> 17); + o[3] = stbi__clamp((x3 + t0) >> 17); + o[4] = stbi__clamp((x3 - t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; - - // dot product constant: even elems=x, odd elems=y - #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) - - // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) - // out(1) = c1[even]*x + c1[odd]*y - #define dct_rot(out0,out1, x,y,c0,c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) - - // out = in << 12 (in 16-bit, out 32-bit) - #define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - - // wide add - #define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - - // wide sub - #define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) - - // butterfly a/b, add bias, then shift by "s" and pack - #define dct_bfly32o(out0, out1, a,b,bias,s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } - - // 8-bit interleave step (for transposes) - #define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) - - // 16-bit interleave step (for transposes) - #define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) - - #define dct_pass(bias,shift) \ - { \ - /* even part */ \ - dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ - dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0,row7, x0,x7,bias,shift); \ - dct_bfly32o(row1,row6, x1,x6,bias,shift); \ - dct_bfly32o(row2,row5, x2,x5,bias,shift); \ - dct_bfly32o(row3,row4, x3,x4,bias,shift); \ - } +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + +// dot product constant: even elems=x, odd elems=y +#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + +// out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = \ + _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = \ + _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); - - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); - - // load - row0 = _mm_load_si128((const __m128i *) (data + 0*8)); - row1 = _mm_load_si128((const __m128i *) (data + 1*8)); - row2 = _mm_load_si128((const __m128i *) (data + 2*8)); - row3 = _mm_load_si128((const __m128i *) (data + 3*8)); - row4 = _mm_load_si128((const __m128i *) (data + 4*8)); - row5 = _mm_load_si128((const __m128i *) (data + 5*8)); - row6 = _mm_load_si128((const __m128i *) (data + 6*8)); - row7 = _mm_load_si128((const __m128i *) (data + 7*8)); - - // column pass - dct_pass(bias_0, 10); - - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); - - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); - - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } - - // row pass - dct_pass(bias_1, 17); - - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); - - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... - - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... - - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... +// wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - // store - _mm_storel_epi64((__m128i *) out, p0); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p2); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p1); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p3); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); - } +// wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + +// butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = \ + _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = \ + _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + +// 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + +// 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), + stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), + stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), + stbi__f2f(1.175875602f)); + __m128i rot1_1 = + dct_const(stbi__f2f(1.175875602f), + stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), + stbi__f2f(-1.961570560f)); + __m128i rot2_1 = + dct_const(stbi__f2f(-1.961570560f), + stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), + stbi__f2f(-0.390180644f)); + __m128i rot3_1 = + dct_const(stbi__f2f(-0.390180644f), + stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); + row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); + row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); + row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); + row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); + row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); + row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); + row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *)out, p0); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p2); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p1); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p3); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2634,198 +2893,236 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); - -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) - -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) - -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ - dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ - dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ - dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ - } - - // load - row0 = vld1q_s16(data + 0*8); - row1 = vld1q_s16(data + 1*8); - row2 = vld1q_s16(data + 2*8); - row3 = vld1q_s16(data + 3*8); - row4 = vld1q_s16(data + 4*8); - row5 = vld1q_s16(data + 5*8); - row6 = vld1q_s16(data + 6*8); - row7 = vld1q_s16(data + 7*8); - - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - - // column pass - dct_pass(vrshrn_n_s32, 10); - - // 16bit 8x8 transpose - { +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ + } + + // load + row0 = vld1q_s16(data + 0 * 8); + row1 = vld1q_s16(data + 1 * 8); + row2 = vld1q_s16(data + 2 * 8); + row3 = vld1q_s16(data + 3 * 8); + row4 = vld1q_s16(data + 4 * 8); + row5 = vld1q_s16(data + 5 * 8); + row6 = vld1q_s16(data + 6 * 8); + row7 = vld1q_s16(data + 7 * 8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } -#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } - - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); - - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); - - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = \ + vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ + } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } - - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); - - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } -#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } - - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! - - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); - - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); - - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); - - // store - vst1_u8(out, p0); out += out_stride; - vst1_u8(out, p1); out += out_stride; - vst1_u8(out, p2); out += out_stride; - vst1_u8(out, p3); out += out_stride; - vst1_u8(out, p4); out += out_stride; - vst1_u8(out, p5); out += out_stride; - vst1_u8(out, p6); out += out_stride; - vst1_u8(out, p7); + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ + } +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ + } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); + out += out_stride; + vst1_u8(out, p1); + out += out_stride; + vst1_u8(out, p2); + out += out_stride; + vst1_u8(out, p3); + out += out_stride; + vst1_u8(out, p4); + out += out_stride; + vst1_u8(out, p5); + out += out_stride; + vst1_u8(out, p6); + out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -2838,1132 +3135,1274 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg *j) -{ - stbi_uc x; - if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } - x = stbi__get8(j->s); - if (x != 0xff) return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) { + stbi_uc x; + if (j->marker != STBI__MARKER_none) { + x = j->marker; + j->marker = STBI__MARKER_none; + return x; + } + x = stbi__get8(j->s); + if (x != 0xff) + return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg *j) -{ - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels -} - -static int stbi__parse_entropy_coded_data(stbi__jpeg *z) -{ - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i,j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - STBI_SIMD_ALIGN(short, data[64]); - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x)*8; - int y2 = (j*z->img_comp[n].v + y)*8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; +static void stbi__jpeg_reset(stbi__jpeg *j) { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = + j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) { + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } else { - if (z->scan_n == 1) { - int i,j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x); - int y2 = (j*z->img_comp[n].v + y); - short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } + return 1; + } else { // interleaved + int i, j, k, x, y; + STBI_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, + z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + + z->img_comp[n].w2 * y2 + x2, + z->img_comp[n].w2, data); + } } - } - return 1; + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) -{ - int i; - for (i=0; i < 64; ++i) - data[i] *= dequant[i]; -} - -static void stbi__jpeg_finish(stbi__jpeg *z) -{ - if (z->progressive) { - // dequantize and idct the data - int i,j,n; - for (n=0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - } - } + return 1; + } + } else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], + z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static int stbi__process_marker(stbi__jpeg *z, int m) -{ - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker","Corrupt JPEG"); - - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; - - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s)-2; - while (L > 0) { - int q = stbi__get8(z->s); - int p = q >> 4, sixteen = (p != 0); - int t = q & 15,i; - if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); - if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - - for (i=0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); - L -= (sixteen ? 129 : 65); - } - return L==0; - - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s)-2; - while (L > 0) { - stbi_uc *v; - int sizes[16],i,n=0; - int q = stbi__get8(z->s); - int tc = q >> 4; - int th = q & 15; - if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); - for (i=0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; - } - L -= 17; - if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; - v = z->huff_dc[th].values; - } else { - if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; - v = z->huff_ac[th].values; + return 1; + } else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } } - for (i=0; i < n; ++i) - v[i] = stbi__get8(z->s); - if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); - L -= n; - } - return L==0; - } - - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len","Corrupt JPEG"); - else - return stbi__err("bad APP len","Corrupt JPEG"); + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - L -= 2; - - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J','F','I','F','\0'}; - int ok = 1; - int i; - for (i=0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; - int ok = 1; - int i; - for (i=0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + } } + } + } +} - stbi__skip(z->s, L); - return 1; - } +static int stbi__process_marker(stbi__jpeg *z, int m) { + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker", "Corrupt JPEG"); - return stbi__err("unknown marker","Corrupt JPEG"); -} + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) + return stbi__err("bad DRI len", "Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; -// after we see SOS -static int stbi__process_scan_header(stbi__jpeg *z) -{ - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); - if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); - for (i=0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s) - 2; + while (L > 0) { int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) return 0; // no match - z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); - z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); - z->order[i] = which; - } - - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) + return stbi__err("bad DQT type", "Corrupt JPEG"); + if (t > 3) + return stbi__err("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = + (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + stbi_uc *v; + int sizes[16], i, n = 0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) + return stbi__err("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc + th, sizes)) + return 0; + v = z->huff_dc[th].values; } else { - if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); - z->spec_end = 63; + if (!stbi__build_huffman(z->huff_ac + th, sizes)) + return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len", "Corrupt JPEG"); + else + return stbi__err("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; } - } + } + + stbi__skip(z->s, L); + return 1; + } - return 1; + return stbi__err("unknown marker", "Corrupt JPEG"); } -static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) -{ - int i; - for (i=0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) { + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) + return stbi__err("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) + return stbi__err("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) + return 0; // no match + z->img_comp[which].hd = q >> 4; + if (z->img_comp[which].hd > 3) + return stbi__err("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; + if (z->img_comp[which].ha > 3) + return stbi__err("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || + z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; } -static int stbi__process_frame_header(stbi__jpeg *z, int scan) -{ - stbi__context *s = z->s; - int Lf,p,i,q, h_max=1,v_max=1,c; - Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG - p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); - s->img_n = c; - for (i=0; i < c; ++i) { +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } - - if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); - - z->rgb = 0; - for (i=0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = { 'R', 'G', 'B' }; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); - z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); - } - - if (scan != STBI__SCAN_load) return 1; - - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); - - for (i=0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; - } - - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; - - for (i=0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); - } - } + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) { + stbi__context *s = z->s; + int Lf, p, i, q, h_max = 1, v_max = 1, c; + Lf = stbi__get16be(s); + if (Lf < 11) + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + p = stbi__get8(s); + if (p != 8) + return stbi__err("only 8-bit", + "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); + if (s->img_y == 0) + return stbi__err( + "no header height", + "JPEG format not supported: delayed height"); // Legal, but we don't + // handle it--but neither + // does IJG + s->img_x = stbi__get16be(s); + if (s->img_x == 0) + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) + return stbi__err("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) + return stbi__err("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = {'R', 'G', 'B'}; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); + if (!z->img_comp[i].h || z->img_comp[i].h > 4) + return stbi__err("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; + if (!z->img_comp[i].v || z->img_comp[i].v > 4) + return stbi__err("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); + if (z->img_comp[i].tq > 3) + return stbi__err("bad TQ", "Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) + return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) + return stbi__err("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) + h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) + v_max = z->img_comp[i].v; + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked + // earlier) so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = + stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i + 1, + stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = + (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3( + z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components( + z, i + 1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = + (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } - return 1; + return 1; } // use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) - -#define stbi__SOF_progressive(x) ((x) == 0xc2) - -static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) -{ - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); - if (scan == STBI__SCAN_type) return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z,m)) return 0; +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) + return stbi__err("no SOI", "Corrupt JPEG"); + if (scan == STBI__SCAN_type) + return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z, m)) + return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) + return stbi__err("no SOF", "Corrupt JPEG"); m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) return 0; - return 1; + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) + return 0; + return 1; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg *j) -{ - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) return 0; - if (!stbi__parse_entropy_coded_data(j)) return 0; - if (j->marker == STBI__MARKER_none ) { - // handle 0s at the end of image data from IP Kamera 9060 - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - if (x == 255) { - j->marker = stbi__get8(j->s); - break; - } - } - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); - } else { - if (!stbi__process_marker(j, m)) return 0; +static int stbi__decode_jpeg_image(stbi__jpeg *j) { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) + return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) + return 0; + if (!stbi__parse_entropy_coded_data(j)) + return 0; + if (j->marker == STBI__MARKER_none) { + // handle 0s at the end of image data from IP Kamera 9060 + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + if (x == 255) { + j->marker = stbi__get8(j->s); + break; + } + } + // if we reach eof without hitting a marker, stbi__get_marker() below + // will fail and we'll eventually return 0 } - m = stbi__get_marker(j); - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) + return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) + return stbi__err("bad DNL height", "Corrupt JPEG"); + } else { + if (!stbi__process_marker(j, m)) + return 0; + } + m = stbi__get_marker(j); + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, - int w, int hs); - -#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) - -static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; -} - -static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i=0; i < w; ++i) - out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); - return out; + int w, int hs); + +#define stbi__div4(x) ((stbi_uc)((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, + int w, int hs) { + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc *stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc *stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = stbi__div4(n + input[i - 1]); + out[i * 2 + 1] = stbi__div4(n + input[i + 1]); + } + out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc)((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = stbi__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; } -static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples horizontally for every one in input - int i; - stbi_uc *input = in_near; - - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } - - out[0] = input[0]; - out[1] = stbi__div4(input[0]*3 + input[1] + 2); - for (i=1; i < w-1; ++i) { - int n = 3*input[i]+2; - out[i*2+0] = stbi__div4(n+input[i-1]); - out[i*2+1] = stbi__div4(n+input[i+1]); - } - out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); - out[i*2+1] = input[w-1]; - - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); - - return out; -} +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = + _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *)(out + i * 2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = + vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif -#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } -static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i,t0,t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - t1 = 3*in_near[0] + in_far[0]; - out[0] = stbi__div4(t1+2); - for (i=1; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } +#endif -#if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i=0,t0,t1; - - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } - - t1 = 3*in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w-1) & ~7); i += 8) { -#if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); - - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); - - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *) (out + i*2), outv); -#elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); - - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i*2, o); -#endif - - // "previous" value for next iter - t1 = 3*in_near[i+7] + in_far[i+7]; - } - - t0 = t1; - t1 = 3*in_near[i] + in_far[i]; - out[i*2] = stbi__div16(3*t1 + t0 + 8); - - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); - - STBI_NOTUSED(hs); - - return out; -} -#endif - -static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // resample with nearest-neighbor - int i,j; - STBI_NOTUSED(in_far); - for (i=0; i < w; ++i) - for (j=0; j < hs; ++j) - out[i*hs+j] = in_near[i]; - return out; +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // resample with nearest-neighbor + int i, j; + STBI_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) -{ - int i; - for (i=0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int)((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, + int count, int step) { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) -{ - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, + stbi_uc const *pcb, stbi_uc const *pcr, + int count, int step) { + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); - __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); - __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); - __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); - __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); - __m128i xw = _mm_set1_epi16(255); // alpha channel - - for (; i+7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); - - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); - - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); - - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); - - // store - _mm_storeu_si128((__m128i *) (out + 0), o0); - _mm_storeu_si128((__m128i *) (out + 16), o1); - out += 32; - } - } + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *)(out + 0), o0); + _mm_storeu_si128((__m128i *)(out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); - int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); - int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); - int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); - - for (; i+7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); - - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); - - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); - - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8*4; - } - } -#endif - - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + cr * -stbi__float2fixed(0.71414f) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg *j) -{ - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) { + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg *j) -{ - stbi__free_jpeg_components(j, j->s->img_n, 0); +static void stbi__cleanup_jpeg(stbi__jpeg *j) { + stbi__free_jpeg_components(j, j->s->img_n, 0); } -typedef struct -{ - resample_row_func resample; - stbi_uc *line0,*line1; - int hs,vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +typedef struct { + resample_row_func resample; + stbi_uc *line0, *line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) -{ - unsigned int t = x*y + 128; - return (stbi_uc) ((t + (t >>8)) >> 8); -} - -static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) -{ - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe - - // validate req_comp - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } - - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; - - // resample and color-convert - { - int k; - unsigned int i,j; - stbi_uc *output; - stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; - - stbi__resample res_comp[4]; - - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs-1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; - else r->resample = stbi__resample_row_generic; +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { + unsigned int t = x * y + 128; + return (stbi_uc)((t + (t >> 8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, + int *comp, int req_comp) { + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && + (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // resample and color-convert + { + int k; + unsigned int i, j; + stbi_uc *output; + stbi_uc *coutput[4] = {NULL, NULL, NULL, NULL}; + + stbi__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); } - // can't error after this so, this is safe - output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - // now go ahead and resample - for (j=0; j < z->s->img_y; ++j) { - stbi_uc *out = output + n * z->s->img_x * j; - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, - y_bot ? r->line1 : r->line0, - y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) + r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) + r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) + r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) + r->resample = z->resample_row_hv_2_kernel; + else + r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = + r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; } - } - if (n >= 3) { - stbi_uc *y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i=0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } - } else { - if (is_rgb) { - if (n == 1) - for (i=0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i=0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc *y = coutput[0]; - if (n == 1) - for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; - else - for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; } - } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { + *out++ = y[i]; + *out++ = 255; + } + } } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } -} - -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - unsigned char* result; - stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x,y,comp,req_comp); - STBI_FREE(j); - return result; -} - -static int stbi__jpeg_test(stbi__context *s) -{ - int r; - stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; -} - -static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) -{ - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind( j->s ); - return 0; - } - if (x) *x = j->s->img_x; - if (y) *y = j->s->img_y; - if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; -} - -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) -{ - int result; - stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) + *comp = + z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + unsigned char *result; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) { + int r; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) { + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind(j->s); + return 0; + } + if (x) + *x = j->s->img_x; + if (y) + *y = j->s->img_y; + if (comp) + *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) { + int result; + stbi__jpeg *j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -3977,83 +4416,81 @@ static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct -{ - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[288]; - stbi__uint16 value[288]; +typedef struct { + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[288]; + stbi__uint16 value[288]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) -{ - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); +stbi_inline static int stbi__bitreverse16(int n) { + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); return n; } -stbi_inline static int stbi__bit_reverse(int v, int bits) -{ - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16-bits); -} - -static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) -{ - int i,k=0; - int code, next_code[16], sizes[17]; - - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i=0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i=1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i=1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16) code; - z->firstsymbol[i] = (stbi__uint16) k; - code = (code + sizes[i]); - if (sizes[i]) - if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); - z->maxcode[i] = code << (16-i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i=0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); - z->size [c] = (stbi_uc ) s; - z->value[c] = (stbi__uint16) i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s],s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } - } - ++next_code[s]; +stbi_inline static int stbi__bit_reverse(int v, int bits) { + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16 - bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, + int num) { + int i, k = 0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i = 0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i = 1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i = 1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16)code; + z->firstsymbol[i] = (stbi__uint16)k; + code = (code + sizes[i]); + if (sizes[i]) + if (code - 1 >= (1 << i)) + return stbi__err("bad codelengths", "Corrupt PNG"); + z->maxcode[i] = code << (16 - i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i = 0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); + z->size[c] = (stbi_uc)s; + z->value[c] = (stbi__uint16)i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s], s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } } - } - return 1; + ++next_code[s]; + } + } + return 1; } // zlib-from-memory implementation for PNG reading @@ -4062,277 +4499,313 @@ static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int // we require PNG read all the IDATs and combine them into a single // memory buffer -typedef struct -{ - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; +typedef struct { + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; - char *zout; - char *zout_start; - char *zout_end; - int z_expandable; + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; - stbi__zhuffman z_length, z_distance; + stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf *z) -{ - return (z->zbuffer >= z->zbuffer_end); -} - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) -{ - return stbi__zeof(z) ? 0 : *z->zbuffer++; -} - -static void stbi__fill_bits(stbi__zbuf *z) -{ - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); +stbi_inline static int stbi__zeof(stbi__zbuf *z) { + return (z->zbuffer >= z->zbuffer_end); } -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) -{ - unsigned int k; - if (z->num_bits < n) stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) { + return stbi__zeof(z) ? 0 : *z->zbuffer++; } -static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s,k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s=STBI__ZFAST_BITS+1; ; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) return -1; // invalid code! - // code size is s, so: - b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= sizeof (z->size)) return -1; // some data was corrupt somewhere! - if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes -{ - char *q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); - cur = (unsigned int) (z->zout - z->zout_start); - limit = old_limit = (unsigned) (z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; +static void stbi__fill_bits(stbi__zbuf *z) { + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) { + unsigned int k; + if (z->num_bits < n) + stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s, k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s = STBI__ZFAST_BITS + 1;; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) + return -1; // invalid code! + // code size is s, so: + b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= sizeof(z->size)) + return -1; // some data was corrupt somewhere! + if (z->size[b] != s) + return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. */ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, + int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) + return stbi__err("output buffer limit", "Corrupt PNG"); + cur = (unsigned int)(z->zout - z->zout_start); + limit = old_limit = (unsigned)(z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned)n) + return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if (limit > UINT_MAX / 2) + return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) + return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; } static const int stbi__zlength_base[31] = { - 3,4,5,6,7,8,9,10,11,13, - 15,17,19,23,27,31,35,43,51,59, - 67,83,99,115,131,163,195,227,258,0,0 }; - -static const int stbi__zlength_extra[31]= -{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; - -static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, -257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; - -static const int stbi__zdist_extra[32] = -{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; - -static int stbi__parse_huffman_block(stbi__zbuf *a) -{ - char *zout = a->zout; - for(;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) return 0; - zout = a->zout; - } - *zout++ = (char) z; + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + +static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, + 4, 4, 5, 5, 5, 5, 0, 0, 0}; + +static const int stbi__zdist_base[32] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; + +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) { + char *zout = a->zout; + for (;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) + return stbi__err("bad huffman code", + "Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) + return 0; + zout = a->zout; + } + *zout++ = (char)z; + } else { + stbi_uc *p; + int len, dist; + if (z == 256) { + a->zout = zout; + return 1; + } + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) + len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0) + return stbi__err("bad huffman code", "Corrupt PNG"); + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) + dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) + return stbi__err("bad dist", "Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) + return 0; + zout = a->zout; + } + p = (stbi_uc *)(zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { + do + *zout++ = v; + while (--len); + } } else { - stbi_uc *p; - int len,dist; - if (z == 256) { - a->zout = zout; - return 1; - } - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) return 0; - zout = a->zout; - } - p = (stbi_uc *) (zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { do *zout++ = v; while (--len); } - } else { - if (len) { do *zout++ = *p++; while (--len); } - } + if (len) { + do + *zout++ = *p++; + while (--len); + } } - } -} - -static int stbi__compute_huffman_codes(stbi__zbuf *a) -{ - static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286+32+137];//padding for maximum single op - stbi_uc codelength_sizes[19]; - int i,n; - - int hlit = stbi__zreceive(a,5) + 257; - int hdist = stbi__zreceive(a,5) + 1; - int hclen = stbi__zreceive(a,4) + 4; - int ntot = hlit + hdist; - - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i=0; i < hclen; ++i) { - int s = stbi__zreceive(a,3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; - - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc) c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a,2)+3; - if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n-1]; - } else if (c == 17) { - c = stbi__zreceive(a,3)+3; - } else if (c == 18) { - c = stbi__zreceive(a,7)+11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); - } - if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes+n, fill, c); - n += c; + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) { + static const stbi_uc length_dezigzag[19] = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc codelength_sizes[19]; + int i, n; + + int hlit = stbi__zreceive(a, 5) + 257; + int hdist = stbi__zreceive(a, 5) + 1; + int hclen = stbi__zreceive(a, 4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i = 0; i < hclen; ++i) { + int s = stbi__zreceive(a, 3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) + return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc)c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a, 2) + 3; + if (n == 0) + return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n - 1]; + } else if (c == 17) { + c = stbi__zreceive(a, 3) + 3; + } else if (c == 18) { + c = stbi__zreceive(a, 7) + 11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); } - } - if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; - return 1; -} - -static int stbi__parse_uncompressed_block(stbi__zbuf *a) -{ - stbi_uc header[4]; - int len,nlen,k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; -} - -static int stbi__parse_zlib_header(stbi__zbuf *a) -{ - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; -} - -static const stbi_uc stbi__zdefault_length[288] = -{ - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 -}; -static const stbi_uc stbi__zdefault_distance[32] = -{ - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 -}; + if (ntot - n < c) + return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes + n, fill, c); + n += c; + } + } + if (n != ntot) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) + return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) { + stbi_uc header[4]; + int len, nlen, k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = + (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) + return stbi__err("zlib corrupt", "Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) + return stbi__err("zlib corrupt", "Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) + return stbi__err("read past buffer", "Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) + return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) { + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if ((cmf * 256 + flg) % 31 != 0) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if (flg & 32) + return stbi__err("no preset dict", + "Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) + return stbi__err("bad compression", + "Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[288] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; +static const stbi_uc stbi__zdefault_distance[32] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; /* Init algorithm: { @@ -4346,117 +4819,131 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) -{ - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a,1); - type = stbi__zreceive(a,2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) return 0; - } else if (type == 3) { - return 0; +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) { + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) + return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a, 1); + type = stbi__zreceive(a, 2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) + return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, 288)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) + return 0; } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , 288)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; - } else { - if (!stbi__compute_huffman_codes(a)) return 0; - } - if (!stbi__parse_huffman_block(a)) return 0; + if (!stbi__compute_huffman_codes(a)) + return 0; } - } while (!final); - return 1; -} - -static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) -{ - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; - - return stbi__parse_zlib(a, parse_header); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) -{ - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int) (a.zout - a.zout_start); - else - return -1; -} - -STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(16384); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer+len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int) (a.zout - a.zout_start); - else - return -1; + if (!stbi__parse_huffman_block(a)) + return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, + int parse_header) { + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, + int *outlen) { + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + char const *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int)(a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, + int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(16384); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int)(a.zout - a.zout_start); + else + return -1; } #endif @@ -4471,1083 +4958,1312 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct -{ - stbi__uint32 length; - stbi__uint32 type; +typedef struct { + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) -{ - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) { + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context *s) -{ - static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; - int i; - for (i=0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) { + static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + int i; + for (i = 0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) + return stbi__err("bad png sig", "Not a PNG"); + return 1; } -typedef struct -{ - stbi__context *s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct { + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; - enum { - STBI__F_none=0, - STBI__F_sub=1, - STBI__F_up=2, - STBI__F_avg=3, - STBI__F_paeth=4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none = 0, + STBI__F_sub = 1, + STBI__F_up = 2, + STBI__F_avg = 3, + STBI__F_paeth = 4, + // synthetic filters used for first scanline to avoid needing a dummy row of + // 0s + STBI__F_avg_first, + STBI__F_paeth_first }; -static stbi_uc first_row_filter[5] = -{ - STBI__F_none, - STBI__F_sub, - STBI__F_none, - STBI__F_avg_first, - STBI__F_paeth_first -}; +static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, + STBI__F_avg_first, STBI__F_paeth_first}; -static int stbi__paeth(int a, int b, int c) -{ - int p = a + b - c; - int pa = abs(p-a); - int pb = abs(p-b); - int pc = abs(p-c); - if (pa <= pb && pa <= pc) return a; - if (pb <= pc) return b; - return c; +static int stbi__paeth(int a, int b, int c) { + int p = a + b - c; + int pa = abs(p - a); + int pb = abs(p - b); + int pc = abs(p - c); + if (pa <= pb && pa <= pc) + return a; + if (pb <= pc) + return b; + return c; } -static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; +static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, + 0, 0, 0, 0x01}; // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) -{ - int bytes = (depth == 16? 2 : 1); - stbi__context *s = a->s; - stbi__uint32 i,j,stride = x*out_n*bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later - - int output_bytes = out_n*bytes; - int filter_bytes = img_n*bytes; - int width = x; - - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); - a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) return stbi__err("outofmem", "Out of memory"); - - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; - - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); - - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *prior; - int filter = *raw++; - - if (filter > 4) - return stbi__err("invalid filter","Corrupt PNG"); - - if (depth < 8) { - if (img_width_bytes > x) return stbi__err("invalid width","Corrupt PNG"); - cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - - // if first row, use special filter that doesn't sample previous row - if (j == 0) filter = first_row_filter[filter]; - - // handle first byte explicitly - for (k=0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none : cur[k] = raw[k]; break; - case STBI__F_sub : cur[k] = raw[k]; break; - case STBI__F_up : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break; - case STBI__F_avg : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break; - case STBI__F_paeth : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break; - case STBI__F_avg_first : cur[k] = raw[k]; break; - case STBI__F_paeth_first: cur[k] = raw[k]; break; - } +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, + stbi__uint32 raw_len, int out_n, + stbi__uint32 x, stbi__uint32 y, int depth, + int color) { + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i, j, stride = x * out_n * bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n * bytes; + int filter_bytes = img_n * bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); + a->out = (stbi_uc *)stbi__malloc_mad3( + x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) + return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) + return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on + // non-interlaced PNGs, but issue #276 reported a PNG in the wild that had + // extra data at the end (all zeros), so just check for raw_len < img_len + // always. + if (raw_len < img_len) + return stbi__err("not enough pixels", "Corrupt PNG"); + + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter", "Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) + return stbi__err("invalid width", "Corrupt PNG"); + cur += + x * out_n - img_width_bytes; // store output to the rightmost img_len + // bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = + cur - + stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) + filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k = 0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none: + cur[k] = raw[k]; + break; + case STBI__F_sub: + cur[k] = raw[k]; + break; + case STBI__F_up: + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); + break; + case STBI__F_paeth: + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); + break; + case STBI__F_avg_first: + cur[k] = raw[k]; + break; + case STBI__F_paeth_first: + cur[k] = raw[k]; + break; } + } - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes+1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1)*filter_bytes; - #define STBI__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: memcpy(cur, raw, nk); break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break; - } - #undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n+1 == out_n); - #define STBI__CASE(f) \ - case f: \ - for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break; - } - #undef STBI__CASE - - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride*j; // start at the beginning of the row again - for (i=0; i < x; ++i,cur+=output_bytes) { - cur[filter_bytes+1] = 255; - } - } + // this is a little gross, so that we don't switch per-pixel or + // per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1) * filter_bytes; +#define STBI__CASE(f) \ + case f: \ + for (k = 0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. + case STBI__F_none: + memcpy(cur, raw, nk); + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - filter_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - filter_bytes], prior[k], + prior[k - filter_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); + } + break; } - } - - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *in = a->out + stride*j + x*out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit - // png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k=x*img_n; k >= 2; k-=2, ++in) { - *cur++ = scale * ((*in >> 4) ); - *cur++ = scale * ((*in ) & 0x0f); - } - if (k > 0) *cur++ = scale * ((*in >> 4) ); - } else if (depth == 2) { - for (k=x*img_n; k >= 4; k-=4, ++in) { - *cur++ = scale * ((*in >> 6) ); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in ) & 0x03); - } - if (k > 0) *cur++ = scale * ((*in >> 6) ); - if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k=x*img_n; k >= 8; k-=8, ++in) { - *cur++ = scale * ((*in >> 7) ); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in ) & 0x01); - } - if (k > 0) *cur++ = scale * ((*in >> 7) ); - if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride*j; - if (img_n == 1) { - for (q=x-1; q >= 0; --q) { - cur[q*2+1] = 255; - cur[q*2+0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q=x-1; q >= 0; --q) { - cur[q*4+3] = 255; - cur[q*4+2] = cur[q*3+2]; - cur[q*4+1] = cur[q*3+1]; - cur[q*4+0] = cur[q*3+0]; - } - } - } +#undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n + 1 == out_n); +#define STBI__CASE(f) \ + case f: \ + for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, \ + cur += output_bytes, prior += output_bytes) \ + for (k = 0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { + cur[k] = raw[k]; + } + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - output_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - output_bytes], prior[k], + prior[k - output_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); + } + break; } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc *cur = a->out; - stbi__uint16 *cur16 = (stbi__uint16*)cur; - - for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) { - *cur16 = (cur[0] << 8) | cur[1]; +#undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride * j; // start at the beginning of the row again + for (i = 0; i < x; ++i, cur += output_bytes) { + cur[filter_bytes + 1] = 255; + } } - } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *in = a->out + stride * j + x * out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common + // 8-bit path optimal at minimal cost for 1/2/4-bit png guarante byte + // alignment, if width is not multiple of 8/4/2 we'll decode dummy + // trailing data that will be skipped in the later loop + stbi_uc scale = (color == 0) + ? stbi__depth_scale_table[depth] + : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than + // desired. we can allocate enough data that this never writes out of + // memory, but it could also overwrite the next scanline. can it overwrite + // non-empty data on the next scanline? yes, consider 1-pixel-wide + // scanlines with 1-bit-per-pixel. so we need to explicitly clamp the + // final ones + + if (depth == 4) { + for (k = x * img_n; k >= 2; k -= 2, ++in) { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } else if (depth == 2) { + for (k = x * img_n; k >= 4; k -= 4, ++in) { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k = x * img_n; k >= 8; k -= 8, ++in) { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) + *cur++ = scale * ((*in >> 7)); + if (k > 1) + *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) + *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) + *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) + *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) + *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) + *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride * j; + if (img_n == 1) { + for (q = x - 1; q >= 0; --q) { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q = x - 1; q >= 0; --q) { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } + } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16 *)cur; + + for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, + stbi__uint32 image_data_len, int out_n, + int depth, int color, int interlaced) { + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, + a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + for (p = 0; p < 7; ++p) { + int xorig[] = {0, 4, 0, 2, 0, 1, 0}; + int yorig[] = {0, 0, 4, 0, 2, 0, 1}; + int xspc[] = {8, 8, 4, 4, 2, 2, 1}; + int yspc[] = {8, 8, 8, 4, 4, 2, 2}; + int i, j, x, y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, + y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j = 0; j < y; ++j) { + for (i = 0; i < x; ++i) { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, + a->out + (j * x + i) * out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - return 1; + return 1; } -static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) -{ - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc *final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - - // de-interlacing - final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - for (p=0; p < 7; ++p) { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; - int i,j,x,y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; - } - for (j=0; j < y; ++j) { - for (i=0; i < x; ++i) { - int out_y = j*yspc[p]+yorig[p]; - int out_x = i*xspc[p]+xorig[p]; - memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, - a->out + (j*x+i)*out_bytes, out_bytes); - } - } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - return 1; -} + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); -static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); - - if (out_n == 2) { - for (i=0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i=0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 *p = (stbi__uint16*) z->out; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], + int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16 *)z->out; - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) -{ - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; - - p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) return stbi__err("outofmem", "Out of memory"); - - // between here and free(out) below, exitting would leak - temp_out = p; - - if (pal_img_n == 3) { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p += 3; - } - } else { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p[3] = palette[n+3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, + int pal_img_n) { + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); - STBI_NOTUSED(len); + // between here and free(out) below, exitting would leak + temp_out = p; - return 1; + if (pal_img_n == 3) { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p += 3; + } + } else { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p[3] = palette[n + 3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; } static int stbi__unpremultiply_on_load = 0; static int stbi__de_iphone_flag = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) -{ - stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) -{ - stbi__de_iphone_flag = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { + stbi__de_iphone_flag = flag_true_if_should_convert; } -static void stbi__de_iphone(stbi__png *z) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - if (s->img_out_n == 3) { // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 3; +static void stbi__de_iphone(stbi__png *z) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i = 0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = (t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i=0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = ( t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } + } else { + // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a, b, c, d) \ + (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + \ + (unsigned)(d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) { + stbi_uc palette[1024], pal_img_n = 0; + stbi_uc has_trans = 0, tc[3] = {0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; + int first = 1, k, interlace = 0, color = 0, is_iphone = 0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) + return 0; + + if (scan == STBI__SCAN_type) + return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { + int comp, filter; + if (!first) + return stbi__err("multiple IHDR", "Corrupt PNG"); + first = 0; + if (c.length != 13) + return stbi__err("bad IHDR len", "Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + z->depth = stbi__get8(s); + if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && + z->depth != 16) + return stbi__err("1/2/4/8/16-bit only", + "PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); + if (color > 6) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3 && z->depth == 16) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3) + pal_img_n = 3; + else if (color & 1) + return stbi__err("bad ctype", "Corrupt PNG"); + comp = stbi__get8(s); + if (comp) + return stbi__err("bad comp method", "Corrupt PNG"); + filter = stbi__get8(s); + if (filter) + return stbi__err("bad filter method", "Corrupt PNG"); + interlace = stbi__get8(s); + if (interlace > 1) + return stbi__err("bad interlace method", "Corrupt PNG"); + if (!s->img_x || !s->img_y) + return stbi__err("0-pixel image", "Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) + return stbi__err("too large", "Image too large to decode"); + if (scan == STBI__SCAN_header) + return 1; } else { - // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) + return stbi__err("too large", "Corrupt PNG"); + // if SCAN_header, have to scan to see if we have a tRNS } - } -} - -#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + break; + } -static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) -{ - stbi_uc palette[1024], pal_img_n=0; - stbi_uc has_trans=0, tc[3]={0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; - int first=1,k,interlace=0, color=0, is_iphone=0; - stbi__context *s = z->s; - - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; - - if (!stbi__check_png_header(s)) return 0; - - if (scan == STBI__SCAN_type) return 1; - - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C','g','B','I'): - is_iphone = 1; - stbi__skip(s, c.length); - break; - case STBI__PNG_TYPE('I','H','D','R'): { - int comp,filter; - if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); - first = 0; - if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); - s->img_x = stbi__get32be(s); - s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); - comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); - filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); - interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); - if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); - if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); - if (scan == STBI__SCAN_header) return 1; - } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); - // if SCAN_header, have to scan to see if we have a tRNS - } - break; - } - - case STBI__PNG_TYPE('P','L','T','E'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); - pal_len = c.length / 3; - if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); - for (i=0; i < pal_len; ++i) { - palette[i*4+0] = stbi__get8(s); - palette[i*4+1] = stbi__get8(s); - palette[i*4+2] = stbi__get8(s); - palette[i*4+3] = 255; - } - break; - } - - case STBI__PNG_TYPE('t','R','N','S'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); - if (pal_img_n) { - if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } - if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); - if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); - pal_img_n = 4; - for (i=0; i < c.length; ++i) - palette[i*4+3] = stbi__get8(s); - } else { - if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); - if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); - has_trans = 1; - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } - } - break; - } - - case STBI__PNG_TYPE('I','D','A','T'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); - if (scan == STBI__SCAN_header) { s->img_n = pal_img_n; return 1; } - if ((int)(ioff + c.length) < (int)ioff) return 0; - if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc *p; - if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); - z->idata = p; - } - if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); - ioff += c.length; - break; - } - - case STBI__PNG_TYPE('I','E','N','D'): { - stbi__uint32 raw_len, bpl; - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) return 1; - if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); - // initial guess for decoded data size to avoid unnecessary reallocs - bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component - raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); - if (z->expanded == NULL) return 0; // zlib should set error - STBI_FREE(z->idata); z->idata = NULL; - if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n+1; - else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; - if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; - } - } - if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); - if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; - } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; - } - STBI_FREE(z->expanded); z->expanded = NULL; - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - return 1; - } - - default: - // if critical, fail - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if ((c.type & (1 << 29)) == 0) { - #ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); - #endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); - } - stbi__skip(s, c.length); - break; + case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256 * 3) + return stbi__err("invalid PLTE", "Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) + return stbi__err("invalid PLTE", "Corrupt PNG"); + for (i = 0; i < pal_len; ++i) { + palette[i * 4 + 0] = stbi__get8(s); + palette[i * 4 + 1] = stbi__get8(s); + palette[i * 4 + 2] = stbi__get8(s); + palette[i * 4 + 3] = 255; } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } -} + break; + } -static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) -{ - void *result=NULL; - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) return result; + case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) + return stbi__err("tRNS after IDAT", "Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { + s->img_n = 4; + return 1; + } + if (pal_len == 0) + return stbi__err("tRNS before PLTE", "Corrupt PNG"); + if (c.length > pal_len) + return stbi__err("bad tRNS len", "Corrupt PNG"); + pal_img_n = 4; + for (i = 0; i < c.length; ++i) + palette[i * 4 + 3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) + return stbi__err("tRNS with alpha", "Corrupt PNG"); + if (c.length != (stbi__uint32)s->img_n * 2) + return stbi__err("bad tRNS len", "Corrupt PNG"); + has_trans = 1; + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * + stbi__depth_scale_table[z->depth]; // non 8-bit images will + // be larger + } } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) *n = p->s->img_n; - } - STBI_FREE(p->out); p->out = NULL; - STBI_FREE(p->expanded); p->expanded = NULL; - STBI_FREE(p->idata); p->idata = NULL; - - return result; -} - -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi__png p; - p.s = s; - return stbi__do_png(&p, x,y,comp,req_comp, ri); -} - -static int stbi__png_test(stbi__context *s) -{ - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; -} + break; + } -static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) -{ - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind( p->s ); - return 0; - } - if (x) *x = p->s->img_x; - if (y) *y = p->s->img_y; - if (comp) *comp = p->s->img_n; - return 1; -} + case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) + return stbi__err("no PLTE", "Corrupt PNG"); + if (scan == STBI__SCAN_header) { + s->img_n = pal_img_n; + return 1; + } + if ((int)(ioff + c.length) < (int)ioff) + return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) + idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, + idata_limit); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata + ioff, c.length)) + return stbi__err("outofdata", "Corrupt PNG"); + ioff += c.length; + break; + } -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); -} + case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + stbi__uint32 raw_len, bpl; + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) + return 1; + if (z->idata == NULL) + return stbi__err("no IDAT", "Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag( + (char *)z->idata, ioff, raw_len, (int *)&raw_len, !is_iphone); + if (z->expanded == NULL) + return 0; // zlib should set error + STBI_FREE(z->idata); + z->idata = NULL; + if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || + has_trans) + s->img_out_n = s->img_n + 1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, + z->depth, color, interlace)) + return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) + return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) + return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) + s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); + z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } -static int stbi__png_is16(stbi__context *s) -{ - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; + default: + // if critical, fail + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { +#ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); +#endif + return stbi__err(invalid_chunk, + "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, + stbi__result_info *ri) { + void *result = NULL; + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", + "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) + return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) + *n = p->s->img_n; + } + STBI_FREE(p->out); + p->out = NULL; + STBI_FREE(p->expanded); + p->expanded = NULL; + STBI_FREE(p->idata); + p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi__png p; + p.s = s; + return stbi__do_png(&p, x, y, comp, req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) { + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) { + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind(p->s); + return 0; + } + if (x) + *x = p->s->img_x; + if (y) + *y = p->s->img_y; + if (comp) + *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) { + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) { + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context *s) -{ - int r; - int sz; - if (stbi__get8(s) != 'B') return 0; - if (stbi__get8(s) != 'M') return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; -} - -static int stbi__bmp_test(stbi__context *s) -{ - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test_raw(stbi__context *s) { + int r; + int sz; + if (stbi__get8(s) != 'B') + return 0; + if (stbi__get8(s) != 'M') + return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) { + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } - // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) -{ - int n=0; - if (z == 0) return -1; - if (z >= 0x10000) { n += 16; z >>= 16; } - if (z >= 0x00100) { n += 8; z >>= 8; } - if (z >= 0x00010) { n += 4; z >>= 4; } - if (z >= 0x00004) { n += 2; z >>= 2; } - if (z >= 0x00002) { n += 1;/* >>= 1;*/ } - return n; +static int stbi__high_bit(unsigned int z) { + int n = 0; + if (z == 0) + return -1; + if (z >= 0x10000) { + n += 16; + z >>= 16; + } + if (z >= 0x00100) { + n += 8; + z >>= 8; + } + if (z >= 0x00010) { + n += 4; + z >>= 4; + } + if (z >= 0x00004) { + n += 2; + z >>= 2; + } + if (z >= 0x00002) { + n += 1; /* >>= 1;*/ + } + return n; } -static int stbi__bitcount(unsigned int a) -{ - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) { + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) -{ - static unsigned int mul_table[9] = { +static int stbi__shiftsigned(unsigned int v, int shift, int bits) { + static unsigned int mul_table[9] = { 0, - 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, - 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0,0,1,0,2,4,6,0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8-bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; -} - -typedef struct -{ - int bpp, offset, hsz; - unsigned int mr,mg,mb,ma, all_a; - int extra_read; + 0xff /*0b11111111*/, + 0x55 /*0b01010101*/, + 0x49 /*0b01001001*/, + 0x11 /*0b00010001*/, + 0x21 /*0b00100001*/, + 0x41 /*0b01000001*/, + 0x81 /*0b10000001*/, + 0x01 /*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0, 0, 1, 0, 2, 4, 6, 0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8 - bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct { + int bpp, offset, hsz; + unsigned int mr, mg, mb, ma, all_a; + int extra_read; } stbi__bmp_data; -static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) -{ - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; - - if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); - - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - int i; - if (hsz != 108 && hsz != 124) +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) { + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') + return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) + return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) + return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) + return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) + return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha + // channel but it was all 0 + } else { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - stbi__get32le(s); // discard color space - for (i=0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); } - } - return (void *) 1; -} - - -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - unsigned int mr=0,mg=0,mb=0,ma=0, all_a; - stbi_uc pal[256][4]; - int psize=0,i,j,width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set - - flip_vertically = ((int) s->img_y) > 0; - s->img_y = abs((int) s->img_y); - - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; - - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - STBI_ASSERT(info.offset == s->callback_already_read + (int) (s->img_buffer - s->img_buffer_original)); - if (info.offset != s->callback_already_read + (s->img_buffer - s->buffer_start)) { - return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + stbi__get32le(s); // discard color space + for (i = 0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved } - } - - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert - - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); - - out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z=0; - if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } - for (i=0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) stbi__get8(s); - pal[i][3] = 255; + } + } + return (void *)1; +} + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; + stbi_uc pal[256][4]; + int psize = 0, i, j, width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int)s->img_y) > 0; + s->img_y = abs((int)s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + STBI_ASSERT(info.offset == + s->callback_already_read + + (int)(s->img_buffer - s->img_buffer_original)); + if (info.offset != + s->callback_already_read + (s->img_buffer - s->buffer_start)) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z = 0; + if (psize == 0 || psize > 256) { + STBI_FREE(out); + return stbi__errpuc("invalid", "Corrupt BMP"); + } + for (i = 0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) + stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - + psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) + width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) + width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) + width = s->img_x; + else { + STBI_FREE(out); + return stbi__errpuc("bad bpp", "Corrupt BMP"); + } + pad = (-width) & 3; + if (info.bpp == 1) { + for (j = 0; j < (int)s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i = 0; i < (int)s->img_x; ++i) { + int color = (v >> bit_offset) & 0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + if ((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) width = s->img_x; - else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } - pad = (-width)&3; - if (info.bpp == 1) { - for (j=0; j < (int) s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i=0; i < (int) s->img_x; ++i) { - int color = (v>>bit_offset)&0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - if((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j=0; j < (int) s->img_y; ++j) { - for (i=0; i < (int) s->img_x; i += 2) { - int v=stbi__get8(s),v2=0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - } - stbi__skip(s, pad); - } + } else { + for (j = 0; j < (int)s->img_y; ++j) { + for (i = 0; i < (int)s->img_x; i += 2) { + int v = stbi__get8(s), v2 = 0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + } + stbi__skip(s, pad); } - } else { - int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; - int z = 0; - int easy=0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) width = 3 * s->img_x; - else if (info.bpp == 16) width = 2*s->img_x; - else /* bpp = 32 and pad = 0 */ width=0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; + } + } else { + int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, + bcount = 0, acount = 0; + int z = 0; + int easy = 0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) + width = 3 * s->img_x; + else if (info.bpp == 16) + width = 2 * s->img_x; + else /* bpp = 32 and pad = 0 */ + width = 0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - if (!easy) { - if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr) - 7; + rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg) - 7; + gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb) - 7; + bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma) - 7; + acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - for (j=0; j < (int) s->img_y; ++j) { - if (easy) { - for (i=0; i < (int) s->img_x; ++i) { - unsigned char a; - out[z+2] = stbi__get8(s); - out[z+1] = stbi__get8(s); - out[z+0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i=0; i < (int) s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) out[z++] = STBI__BYTECAST(a); - } - } - stbi__skip(s, pad); + } + for (j = 0; j < (int)s->img_y; ++j) { + if (easy) { + for (i = 0; i < (int)s->img_x; ++i) { + unsigned char a; + out[z + 2] = stbi__get8(s); + out[z + 1] = stbi__get8(s); + out[z + 0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) + out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i = 0; i < (int)s->img_x; ++i) { + stbi__uint32 v = + (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) + out[z++] = STBI__BYTECAST(a); + } } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j=0; j < (int) s->img_y>>1; ++j) { - stbi_uc *p1 = out + j *s->img_x*target; - stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; - for (i=0; i < (int) s->img_x*target; ++i) { - t = p1[i]; p1[i] = p2[i]; p2[i] = t; - } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j = 0; j < (int)s->img_y >> 1; ++j) { + stbi_uc *p1 = out + j * s->img_x * target; + stbi_uc *p2 = out + (s->img_y - 1 - j) * s->img_x * target; + for (i = 0; i < (int)s->img_x * target; ++i) { + t = p1[i]; + p1[i] = p2[i]; + p2[i] = t; } - } + } + } - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; - return out; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + return out; } #endif @@ -5555,592 +6271,625 @@ static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) -{ - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) *is_rgb16 = 0; - switch(bits_per_pixel) { - case 8: return STBI_grey; - case 16: if(is_grey) return STBI_grey_alpha; - // fallthrough - case 15: if(is_rgb16) *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: return bits_per_pixel/8; - default: return 0; - } -} - -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) -{ - int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; - int sz, tga_colormap_type; - stbi__get8(s); // discard Offset - tga_colormap_type = stbi__get8(s); // colormap type - if( tga_colormap_type > 1 ) { - stbi__rewind(s); - return 0; // only RGB or indexed allowed - } - tga_image_type = stbi__get8(s); // image type - if ( tga_colormap_type == 1 ) { // colormapped (paletted) image - if (tga_image_type != 1 && tga_image_type != 9) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip image x and y origin - tga_colormap_bpp = sz; - } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { - stbi__rewind(s); - return 0; // only RGB or grey allowed, +/- RLE - } - stbi__skip(s,9); // skip colormap specification and image x/y origin - tga_colormap_bpp = 0; - } - tga_w = stbi__get16le(s); - if( tga_w < 1 ) { - stbi__rewind(s); - return 0; // test width +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int *is_rgb16) { + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) + *is_rgb16 = 0; + switch (bits_per_pixel) { + case 8: + return STBI_grey; + case 16: + if (is_grey) + return STBI_grey_alpha; + // fallthrough + case 15: + if (is_rgb16) + *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: + return bits_per_pixel / 8; + default: + return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) { + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, + tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if (tga_colormap_type > 1) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; } - tga_h = stbi__get16le(s); - if( tga_h < 1 ) { - stbi__rewind(s); - return 0; // test height + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__rewind(s); + return 0; } - tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits - if (tga_colormap_bpp != 0) { - if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { - // when using a colormap, tga_bits_per_pixel is the size of the indexes - // I don't think anything but 8 or 16bit indexes makes sense - stbi__rewind(s); - return 0; - } - tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); - } else { - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + stbi__skip(s, 4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ((tga_image_type != 2) && (tga_image_type != 3) && + (tga_image_type != 10) && (tga_image_type != 11)) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE } - if(!tga_comp) { + stbi__skip(s, 9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if (tga_w < 1) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if (tga_h < 1) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); return 0; } - if (x) *x = tga_w; - if (y) *y = tga_h; - if (comp) *comp = tga_comp; - return 1; // seems to have passed everything -} - -static int stbi__tga_test(stbi__context *s) -{ - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if ( tga_color_type == 1 ) { // colormapped (paletted) image - if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - stbi__skip(s,4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s,9); // skip colormap specification and image x/y origin - } - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp( + tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), + NULL); + } + if (!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) + *x = tga_w; + if (y) + *y = tga_h; + if (comp) + *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) { + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if (tga_color_type > 1) + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image + if (sz != 1 && sz != 9) + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin + } + if (stbi__get16le(s) < 1) + goto errorEnd; // test width + if (stbi__get16le(s) < 1) + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) + goto errorEnd; // for colormapped images, bpp is size of an index + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead + // of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) -{ - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255)/31); - out[1] = (stbi_uc)((g * 255)/31); - out[2] = (stbi_uc)((b * 255)/31); - - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. -} - -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16=0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char *tga_data; - unsigned char *tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO - - if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // do a tiny bit of precessing - if ( tga_image_type >= 8 ) - { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); - - // If I'm paletted, then I'll use the number of bits from the palette - if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - - if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - - // tga info - *x = tga_width; - *y = tga_height; - if (comp) *comp = tga_comp; - - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); - - tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); - - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset ); - - if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { - for (i=0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height -i - 1 : i; - stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if ( tga_indexed) - { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start ); - // load the palette - tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); - } - if (tga_rgb16) { - stbi_uc *pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i=0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc *out) { + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be + // swapped later + out[0] = (stbi_uc)((r * 255) / 31); + out[1] = (stbi_uc)((g * 255) / 31); + out[2] = (stbi_uc)((b * 255) / 31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16 = 0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused + // (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // do a tiny bit of precessing + if (tga_image_type >= 8) { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if (tga_indexed) + tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), + &tga_rgb16); + + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have + // ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) + *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = + (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) + return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset); + + if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { + for (i = 0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height - i - 1 : i; + stbi_uc *tga_row = tga_data + row * tga_width * tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if (tga_indexed) { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // load the data - for (i=0; i < tga_width * tga_height; ++i) - { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if ( tga_is_RLE ) - { - if ( RLE_count == 0 ) - { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if ( !RLE_repeating ) - { - read_next_pixel = 1; - } - } else - { - read_next_pixel = 1; - } - // OK, if I need to read a pixel, do it now - if ( read_next_pixel ) - { - // load however much data we did have - if ( tga_indexed ) - { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if ( pal_idx >= tga_palette_len ) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx+j]; - } - } else if(tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel - - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i*tga_comp+j] = raw_data[j]; - // in case we're in RLE mode, keep counting down - --RLE_count; + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start); + // load the palette + tga_palette = + (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); } - // do I need to invert the image? - if ( tga_inverted ) - { - for (j = 0; j*2 < tga_height; ++j) - { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) - { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } - } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i = 0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // clear my palette, if I had one - if ( tga_palette != NULL ) - { - STBI_FREE( tga_palette ); + } + // load the data + for (i = 0; i < tga_width * tga_height; ++i) { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if (tga_is_RLE) { + if (RLE_count == 0) { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if (!RLE_repeating) { + read_next_pixel = 1; + } + } else { + read_next_pixel = 1; } - } + // OK, if I need to read a pixel, do it now + if (read_next_pixel) { + // load however much data we did have + if (tga_indexed) { + // read in index, then perform the lookup + int pal_idx = + (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if (pal_idx >= tga_palette_len) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx + j]; + } + } else if (tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) - { - unsigned char* tga_pixel = tga_data; - for (i=0; i < tga_width * tga_height; ++i) - { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i * tga_comp + j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if (tga_inverted) { + for (j = 0; j * 2 < tga_height; ++j) { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } } - } + } + // clear my palette, if I had one + if (tga_palette != NULL) { + STBI_FREE(tga_palette); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) { + unsigned char *tga_pixel = tga_data; + for (i = 0; i < tga_width * tga_height; ++i) { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, + tga_height); - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = - tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = + tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; } #endif // ************************************************************************************************* -// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, +// tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s) -{ - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; -} - -static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) -{ - int count, nleft, len; - - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; - } +static int stbi__psd_test(stbi__context *s) { + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) { + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) + return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; } - } - - return 1; -} + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) + return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w, h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", + "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", + "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for + // other modes.) + stbi__skip(s, stbi__get32be(s)); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s)); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s)); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", + "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *)stbi__malloc(4 * w * h); + + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w * h; + + // Initialize the data to zero. + // memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes + // literally. Else if n is between -127 and -1 inclusive, copy the next + // byte -n+1 times. Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row + // in the data, which we're going to just skip. + stbi__skip(s, h * channelCount * 2); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out + channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w,h; - stbi_uc *out; - STBI_NOTUSED(ri); - - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); - - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - - // Skip 6 reserved bytes. - stbi__skip(s, 6 ); - - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); - - if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s,stbi__get32be(s) ); - - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s) ); - - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s) ); - - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); - - // Create the destination image. - - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *) stbi__malloc(4 * w*h); - - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w*h; - - // Initialize the data to zero. - //memset( out, 0, pixelCount * 4 ); - - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop - - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2 ); - - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc *p; - - p = out+channel; - if (channel >= channelCount) { - // Fill this channel with default data. + } else { + // We're at the raw image data. It's each channel in order (Red, Green, + // Blue, Alpha, ...) where each channel consists of an 8-bit (or 16-bit) + // value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out + channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16)stbi__get16be(s); + } else { + stbi_uc *p = out + channel; + if (bitdepth == 16) { // input bpc for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); - } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } - } + *p = (stbi_uc)(stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } } - - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc *p = out+channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } - } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16) stbi__get16be(s); - } else { - stbi_uc *p = out+channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc) (stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } - } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i = 0; i < w * h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *)out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); + pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); + pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); + } } - } - - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i=0; i < w*h; ++i) { - stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); - pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); - pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); - } - } - } else { - for (i=0; i < w*h; ++i) { - unsigned char *pixel = out + 4*i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); - pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); - pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); - } - } + } else { + for (i = 0; i < w * h; ++i) { + unsigned char *pixel = out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); + pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); + pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); + } } - } - - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - - if (comp) *comp = 4; - *y = h; - *x = w; - - return out; + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, + w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + if (comp) + *comp = 4; + *y = h; + *x = w; + + return out; } #endif @@ -6152,215 +6901,222 @@ static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context *s,const char *str) -{ - int i; - for (i=0; i<4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s, const char *str) { + int i; + for (i = 0; i < 4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context *s) -{ - int i; +static int stbi__pic_test_core(stbi__context *s) { + int i; - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) + return 0; - for(i=0;i<84;++i) - stbi__get8(s); + for (i = 0; i < 84; ++i) + stbi__get8(s); - if (!stbi__pic_is4(s,"PICT")) - return 0; + if (!stbi__pic_is4(s, "PICT")) + return 0; - return 1; + return 1; } -typedef struct -{ - stbi_uc size,type,channel; +typedef struct { + stbi_uc size, type, channel; } stbi__pic_packet; -static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) -{ - int mask=0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) { + int mask = 0x80, i; - for (i=0; i<4; ++i, mask>>=1) { - if (channel & mask) { - if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); - dest[i]=stbi__get8(s); - } - } + for (i = 0; i < 4; ++i, mask >>= 1) { + if (channel & mask) { + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "PIC file too short"); + dest[i] = stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) -{ - int mask=0x80,i; +static void stbi__copyval(int channel, stbi_uc *dest, const stbi_uc *src) { + int mask = 0x80, i; - for (i=0;i<4; ++i, mask>>=1) - if (channel&mask) - dest[i]=src[i]; + for (i = 0; i < 4; ++i, mask >>= 1) + if (channel & mask) + dest[i] = src[i]; } -static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) -{ - int act_comp=0,num_packets=0,y,chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s, int width, int height, + int *comp, stbi_uc *result) { + int act_comp = 0, num_packets = 0, y, chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data - // for the same channel in multiple packets. - do { - stbi__pic_packet *packet; + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. + do { + stbi__pic_packet *packet; - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return stbi__errpuc("bad format","too many packets"); + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return stbi__errpuc("bad format", "too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); - if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (reading packets)"); + if (packet->size != 8) + return stbi__errpuc("bad format", "packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for(y=0; ytype) { - default: - return stbi__errpuc("bad format","packet has bad compression type"); + switch (packet->type) { + default: + return stbi__errpuc("bad format", "packet has bad compression type"); - case 0: {//uncompressed - int x; + case 0: { // uncompressed + int x; - for(x=0;xchannel,dest)) - return 0; - break; - } + for (x = 0; x < width; ++x, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + break; + } - case 1://Pure RLE - { - int left=width, i; - - while (left>0) { - stbi_uc count,value[4]; - - count=stbi__get8(s); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - - if (count > left) - count = (stbi_uc) left; - - if (!stbi__readval(s,packet->channel,value)) return 0; - - for(i=0; ichannel,dest,value); - left -= count; - } - } - break; - - case 2: {//Mixed RLE - int left=width; - while (left>0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); - - if (count >= 128) { // Repeated - stbi_uc value[4]; - - if (count==128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file","scanline overrun"); - - if (!stbi__readval(s,packet->channel,value)) - return 0; - - for(i=0;ichannel,dest,value); - } else { // Raw - ++count; - if (count>left) return stbi__errpuc("bad file","scanline overrun"); - - for(i=0;ichannel,dest)) - return 0; - } - left-=count; - } - break; - } - } + case 1: // Pure RLE + { + int left = width, i; + + while (left > 0) { + stbi_uc count, value[4]; + + count = stbi__get8(s); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pure read count)"); + + if (count > left) + count = (stbi_uc)left; + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + left -= count; + } + } break; + + case 2: { // Mixed RLE + int left = width; + while (left > 0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", + "file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count == 128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + } else { // Raw + ++count; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + for (i = 0; i < count; ++i, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + } + left -= count; + } + break; + } } - } + } + } - return result; + return result; } -static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) -{ - stbi_uc *result; - int i, x,y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s, int *px, int *py, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *result; + int i, x, y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) comp = &internal_comp; + if (!comp) + comp = &internal_comp; - for (i=0; i<92; ++i) - stbi__get8(s); + for (i = 0; i < 92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) + return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); //skip `ratio' - stbi__get16be(s); //skip `fields' - stbi__get16be(s); //skip `pad' + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); - memset(result, 0xff, x*y*4); + // intermediate buffer is RGBA + result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); + memset(result, 0xff, x * y * 4); - if (!stbi__pic_load_core(s,x,y,comp, result)) { - STBI_FREE(result); - result=0; - } - *px = x; - *py = y; - if (req_comp == 0) req_comp = *comp; - result=stbi__convert_format(result,4,req_comp,x,y); + if (!stbi__pic_load_core(s, x, y, comp, result)) { + STBI_FREE(result); + result = 0; + } + *px = x; + *py = y; + if (req_comp == 0) + req_comp = *comp; + result = stbi__convert_format(result, 4, req_comp, x, y); - return result; + return result; } -static int stbi__pic_test(stbi__context *s) -{ - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) { + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6368,514 +7124,539 @@ static int stbi__pic_test(stbi__context *s) // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct -{ - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct { + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct -{ - int w,h; - stbi_uc *out; // output buffer (always 4 components) - stbi_uc *background; // The current "background" as far as a gif is concerned - stbi_uc *history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc *color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct { + int w, h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context *s) -{ - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') return 0; - if (stbi__get8(s) != 'a') return 0; - return 1; -} - -static int stbi__gif_test(stbi__context *s) -{ - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; -} - -static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) -{ - int i; - for (i=0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } -} +static int stbi__gif_test_raw(stbi__context *s) { + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') + return 0; + if (stbi__get8(s) != 'a') + return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) { + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], + int num_entries, int transp) { + int i; + for (i = 0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, + int is_info) { + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') + return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') + return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + + if (comp != 0) + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the + // comments + + if (is_info) + return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) { + stbi__gif *g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind(s); + return 0; + } + if (x) + *x = g->w; + if (y) + *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) { + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) + return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) { + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) + return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc)init_code; + g->codes[init_code].suffix = (stbi_uc)init_code; + } + + // support no starting clear code + avail = clear + 2; + oldcode = -1; + + len = 0; + for (;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32)stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s, len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } -static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) -{ - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - version = stbi__get8(s); - if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + p->prefix = (stbi__int16)oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; + stbi__out_gif_code(g, (stbi__uint16)code); - if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } - if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image +// doesn't support it two back is the image from two frames ago, used for a very +// specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, + int req_comp, stbi_uc *two_back) { + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour + // (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp, 0)) + return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *)stbi__malloc(4 * pcount); + g->background = (stbi_uc *)stbi__malloc(4 * pcount); + g->history = (stbi_uc *)stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); - if (is_info) return 1; + // image is treated as "transparent" at the start - ie, nothing overwrites + // the current background; background colour is only used for pixels that + // are not rendered first frame, after that "background" color refers to the + // color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, + 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, + pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the + // old background + } - if (g->flags & 0x80) - stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } - return 1; -} + // background is what out is after the undoing of the previou frame; + memcpy(g->background, g->out, 4 * g->w * g->h); + } + + // clear my history; + memset(g->history, 0x00, + g->w * g->h); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc *o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; -static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind( s ); - return 0; - } - if (x) *x = g->w; - if (y) *y = g->h; - STBI_FREE(g); - return 1; -} + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; -static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) -{ - stbi_uc *p, *c; - int idx; - - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); - - if (g->cur_y >= g->max_y) return; - - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; - - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; - - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + g->lflags = stbi__get8(s); - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; } - } -} -static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) -{ - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw *p; - - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc) init_code; - g->codes[init_code].suffix = (stbi_uc) init_code; - } - - // support no starting clear code - avail = clear+2; - oldcode = -1; - - len = 0; - for(;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; - } - --len; - bits |= (stbi__int32) stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s,len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), + g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *)g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *)g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } + o = stbi__process_gif_raster(s, g); + if (!o) + return NULL; - p->prefix = (stbi__int16) oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = + 255; // just in case it was made transparent, undo that; It will + // be reset next frame if need be; + memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); + } + } + } - stbi__out_gif_code(g, (stbi__uint16) code); + return o; + } - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = + 10 * stbi__get16le( + s); // delay - 1/100th of a second, saving as 1/1000ths. + + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } - } -} - -// this function is designed to support animated gifs, although stb_image doesn't support it -// two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) -{ - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); - - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *) stbi__malloc(4 * pcount); - g->background = (stbi_uc *) stbi__malloc(4 * pcount); - g->history = (stbi_uc *) stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); - - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; - - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); } + break; + } - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); - } - } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); - } - } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers - // background is what out is after the undoing of the previou frame; - memcpy( g->background, g->out, 4 * g->w * g->h ); - } - - // clear my history; - memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame - - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { - stbi__int32 x, y, w, h; - stbi_uc *o; - - x = stbi__get16le(s); - y = stbi__get16le(s); - w = stbi__get16le(s); - h = stbi__get16le(s); - if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); - - g->line_size = g->w * 4; - g->start_x = x * 4; - g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; - - // if the width of the specified rectangle is 0, that means - // we may not see *any* pixels or the image is malformed; - // to make sure this is caught, move the current y down to - // max_y (which is what out_gif_code checks). - if (w == 0) - g->cur_y = g->max_y; - - g->lflags = stbi__get8(s); - - if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; - } else { - g->step = g->line_size; - g->parse = 0; - } + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp) { + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } - if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *) g->lpal; - } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *) g->pal; - } else - return stbi__errpuc("missing color table", "Corrupt GIF"); - - o = stbi__process_gif_raster(s, g); - if (!o) return NULL; - - // if this was the first frame, - pcount = g->w * g->h; - if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); - } - } - } + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = + (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + if (NULL == tmp) { + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + return stbi__errpuc("outofmem", "Out of memory"); + } else { + out = (stbi_uc *)tmp; + out_size = layers * stride; + } + + if (delays) { + *delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, + sizeof(int) * layers); + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc *)stbi__malloc(layers * stride); + out_size = layers * stride; + if (delays) { + *delays = (int *)stbi__malloc(layers * sizeof(int)); + delays_size = layers * sizeof(int); + } + } + memcpy(out + ((layers - 1) * stride), u, stride); + if (layers >= 2) { + two_back = out - 2 * stride; + } - return o; - } - - case 0x21: // Comment Extension. - { - int len; - int ext = stbi__get8(s); - if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } - } - while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); - } - break; - } + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); - case 0x3B: // gif stream termination code - return (stbi_uc *) s; // using '1' causes warning on some compilers + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - default: - return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } -} + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc *u = 0; - stbi_uc *out = 0; - stbi_uc *two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; - - if (out) { - void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); - if (NULL == tmp) { - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); - return stbi__errpuc("outofmem", "Out of memory"); - } - else { - out = (stbi_uc*) tmp; - out_size = layers * stride; - } - - if (delays) { - *delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc*)stbi__malloc( layers * stride ); - out_size = layers * stride; - if (delays) { - *delays = (int*) stbi__malloc( layers * sizeof(int) ); - delays_size = layers * sizeof(int); - } - } - memcpy( out + ((layers - 1) * stride), u, stride ); - if (layers >= 2) { - two_back = out - 2 * stride; - } +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); - if (delays) { - (*delays)[layers - 1U] = g.delay; - } - } - } while (u != 0); + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + return u; } -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); - - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } - - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); - - return u; -} - -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) -{ - return stbi__gif_info_raw(s,x,y,comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) { + return stbi__gif_info_raw(s, x, y, comp); } #endif @@ -6883,396 +7664,434 @@ static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context *s, const char *signature) -{ - int i; - for (i=0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; -} - -static int stbi__hdr_test(stbi__context* s) -{ - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if(!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; -} - -#define STBI__HDR_BUFLEN 1024 -static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) -{ - int len=0; - char c = '\0'; - - c = (char) stbi__get8(z); - - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN-1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) { + int i; + for (i = 0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context *s) { + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if (!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) { + int len = 0; + char c = '\0'; + + c = (char)stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN - 1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char)stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) { + if (input[3] != 0) { + float f1; + // Exponent + f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) + output[1] = 1; + if (req_comp == 4) + output[3] = 1; + } else { + switch (req_comp) { + case 4: + output[3] = 1; /* fallthrough */ + case 3: + output[0] = output[1] = output[2] = 0; + break; + case 2: + output[1] = 1; /* fallthrough */ + case 1: + output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1, c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s, buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && + strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) + return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int)strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) + *comp = 3; + if (req_comp == 0) + req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = + (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if (width < 8 || width >= 32768) { + // Read flat data + for (j = 0; j < height; ++j) { + for (i = 0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, + req_comp); } - c = (char) stbi__get8(z); - } - - buffer[len] = 0; - return buffer; -} + } + } else { + // Read RLE-encoded data + scanline = NULL; -static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) -{ - if ( input[3] != 0 ) { - float f1; - // Exponent - f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a + // decoded pixel (note this can't be a valid pixel--one of RGB must be + // >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc)c1; + rgbe[1] = (stbi_uc)c2; + rgbe[2] = (stbi_uc)len; + rgbe[3] = (stbi_uc)stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense } - if (req_comp == 2) output[1] = 1; - if (req_comp == 4) output[3] = 1; - } else { - switch (req_comp) { - case 4: output[3] = 1; /* fallthrough */ - case 3: output[0] = output[1] = output[2] = 0; - break; - case 2: output[1] = 1; /* fallthrough */ - case 1: output[0] = 0; - break; + len <<= 8; + len |= stbi__get8(s); + if (len != width) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - } -} - -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int width, height; - stbi_uc *scanline; - float *hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1,c2, z; - const char *headerToken; - STBI_NOTUSED(ri); - - // Check identifier - headerToken = stbi__hdr_gettoken(s,buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); - - // Parse header - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); - - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int) strtol(token, NULL, 10); - - if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - - *x = width; - *y = height; - - if (comp) *comp = 3; - if (req_comp == 0) req_comp = 3; - - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); - - // Read data - hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); - - // Load image data - // image data is stored as some number of sca - if ( width < 8 || width >= 32768) { - // Read flat data - for (j=0; j < height; ++j) { - for (i=0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } + if (scanline == NULL) { + scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } } - } else { - // Read RLE-encoded data - scanline = NULL; - - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc) c1; - rgbe[1] = (stbi_uc) c2; - rgbe[2] = (stbi_uc) len; - rgbe[3] = (stbi_uc) stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - if (scanline == NULL) { - scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - for (i=0; i < width; ++i) - stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } } - if (scanline) - STBI_FREE(scanline); - } - - return hdr_data; -} - -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int dummy; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (stbi__hdr_test(s) == 0) { - stbi__rewind( s ); - return 0; - } - - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) { - stbi__rewind( s ); - return 0; - } - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *y = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *x = (int) strtol(token, NULL, 10); - *comp = 3; - return 1; + for (i = 0; i < width; ++i) + stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, + scanline + i * 4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind(s); + return 0; + } + + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) { + stbi__rewind(s); + return 0; + } + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *y = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *x = (int)strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) -{ - void *p; - stbi__bmp_data info; - - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - stbi__rewind( s ); - if (p == NULL) - return 0; - if (x) *x = s->img_x; - if (y) *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) { + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + stbi__rewind(s); + if (p == NULL) + return 0; + if (x) + *x = s->img_x; + if (y) + *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) -{ - int channelCount, dummy, depth; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind( s ); - return 0; - } - *comp = 4; - return 1; -} - -static int stbi__psd_is16(stbi__context *s) -{ - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - (void) stbi__get32be(s); - (void) stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) { + int channelCount, dummy, depth; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind(s); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) { + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + (void)stbi__get32be(s); + (void)stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind(s); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) -{ - int act_comp=0,num_packets=0,chained,dummy; - stbi__pic_packet packets[10]; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { - stbi__rewind(s); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) { + int act_comp = 0, num_packets = 0, chained, dummy; + stbi__pic_packet packets[10]; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) return 0; - } - stbi__skip(s, 88); + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind( s); + if (stbi__at_eof(s)) { + stbi__rewind(s); return 0; - } - if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind( s ); + } + if (packet->size != 8) { + stbi__rewind(s); return 0; - } - - stbi__skip(s, 8); - - do { - stbi__pic_packet *packet; - - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return 0; - - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; - - if (stbi__at_eof(s)) { - stbi__rewind( s ); - return 0; - } - if (packet->size != 8) { - stbi__rewind( s ); - return 0; - } - } while (chained); + } + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); + *comp = (act_comp & 0x10 ? 4 : 3); - return 1; + return 1; } #endif @@ -7290,257 +8109,266 @@ static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s) -{ - char p, t; - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) { + char p, t; + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + return 1; } -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - STBI_NOTUSED(ri); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + STBI_NOTUSED(ri); - if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) - return 0; + if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) + return 0; - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; - if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "PNM too large"); + if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "PNM too large"); - out = (stbi_uc *) stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - stbi__getn(s, out, s->img_n * s->img_x * s->img_y); + out = (stbi_uc *)stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + stbi__getn(s, out, s->img_n * s->img_x * s->img_y); - if (req_comp && req_comp != s->img_n) { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - return out; + if (req_comp && req_comp != s->img_n) { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + return out; } -static int stbi__pnm_isspace(char c) -{ - return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +static int stbi__pnm_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; } -static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) -{ - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char) stbi__get8(s); +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) { + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char)stbi__get8(s); - if (stbi__at_eof(s) || *c != '#') - break; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') + *c = (char)stbi__get8(s); + } } -static int stbi__pnm_isdigit(char c) -{ - return c >= '0' && c <= '9'; +static int stbi__pnm_isdigit(char c) { + return c >= '0' && c <= '9'; } -static int stbi__pnm_getinteger(stbi__context *s, char *c) -{ - int value = 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) { + int value = 0; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value*10 + (*c - '0'); - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value * 10 + (*c - '0'); + *c = (char)stbi__get8(s); + } - return value; + return value; } -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) -{ - int maxv, dummy; - char c, p, t; +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) { + int maxv, dummy; + char c, p, t; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; - stbi__rewind(s); + stbi__rewind(s); - // Get identifier - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } + // Get identifier + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = + (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - c = (char) stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); + c = (char)stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width - stbi__pnm_skip_whitespace(s, &c); + *x = stbi__pnm_getinteger(s, &c); // read width + stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height - stbi__pnm_skip_whitespace(s, &c); + *y = stbi__pnm_getinteger(s, &c); // read height + stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value + maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 255) - return stbi__err("max value > 255", "PPM image not 8-bit"); - else - return 1; + if (maxv > 255) + return stbi__err("max value > 255", "PPM image not 8-bit"); + else + return 1; } #endif -static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) -{ - #ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) return 1; - #endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) { +#ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) + return 1; +#endif - // test tga last because it's a crappy test! - #ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; - #endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +// test tga last because it's a crappy test! +#ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; +#endif + return stbi__err("unknown image type", + "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context *s) -{ - #ifndef STBI_NO_PNG - if (stbi__png_is16(s)) return 1; - #endif +static int stbi__is_16_main(stbi__context *s) { +#ifndef STBI_NO_PNG + if (stbi__png_is16(s)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) + return 1; +#endif - return 0; + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_info_from_file(f, x, y, comp); - fclose(f); - return result; -} - -STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s,x,y,comp); - fseek(f,pos,SEEK_SET); - return r; -} - -STBIDEF int stbi_is_16_bit(char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_is_16_bit_from_file(f); - fclose(f); - return result; -} - -STBIDEF int stbi_is_16_bit_from_file(FILE *f) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f,pos,SEEK_SET); - return r; +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s, x, y, comp); + fseek(f, pos, SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE *f) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f, pos, SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, + int *x, int *y, int *comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, + void *user) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION /* revision history: - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs - 2.19 (2018-02-11) fix warning - 2.18 (2018-01-30) fix warnings - 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and + platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix + warnings 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug 1-bit BMP *_is_16_bit api avoid warnings @@ -7555,13 +8383,11 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user warning fixes; disable run-time SSE detection on gcc; uniform handling of optional "return" values; thread-safe initialization of zlib tables - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) allocate large structures on the stack - remove white matting for transparent PSD - fix reported channel count for PNG & BMP - re-enable SSE2 in non-gcc 64-bit + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet + JPGs 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now 2.12 + (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 (2016-04-02) + allocate large structures on the stack remove white matting for transparent + PSD fix reported channel count for PNG & BMP re-enable SSE2 in non-gcc 64-bit support RGB-formatted JPEG read 16-bit PNGs (only as 8-bit) 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED @@ -7569,11 +8395,9 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 16-bit-per-pixel TGA (not bit-per-component) info() for TGA could break due to .hdr handling info() for BMP to shares code instead of sloppy parse - can use STBI_REALLOC_SIZED if allocator doesn't support realloc - code cleanup - 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA - 2.07 (2015-09-13) fix compiler warnings - partial animated GIF support + can use STBI_REALLOC_SIZED if allocator doesn't support + realloc code cleanup 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD + as RGBA 2.07 (2015-09-13) fix compiler warnings partial animated GIF support limited 16-bpc PSD support #ifdef unused functions bug with < 92 byte PIC,PNM,HDR,TGA @@ -7584,23 +8408,18 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user stbi_set_flip_vertically_on_load (nguillemot) fix NEON support; fix mingw support 2.02 (2015-01-19) fix incorrect assert, fix warning - 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 - 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG - 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) - progressive JPEG (stb) - PGM/PPM support (Ken Miller) - STBI_MALLOC,STBI_REALLOC,STBI_FREE + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit + without -msse2 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG 2.00 + (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) progressive + JPEG (stb) PGM/PPM support (Ken Miller) STBI_MALLOC,STBI_REALLOC,STBI_FREE GIF bugfix -- seemingly never worked STBI_NO_*, STBI_ONLY_* 1.48 (2014-12-14) fix incorrectly-named assert() - 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) - optimize PNG (ryg) - fix bug in interlaced PNG with user-specified channel count (stb) - 1.46 (2014-08-26) - fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG - 1.45 (2014-08-16) - fix MSVC-ARM internal compiler error by wrapping malloc - 1.44 (2014-08-07) + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar + Cornut & stb) optimize PNG (ryg) fix bug in interlaced PNG with + user-specified channel count (stb) 1.46 (2014-08-26) fix broken tRNS chunk + (colorkey-style transparency) in non-paletted PNG 1.45 (2014-08-16) fix + MSVC-ARM internal compiler error by wrapping malloc 1.44 (2014-08-07) various warning fixes from Ronny Chevalier 1.43 (2014-07-15) fix MSVC-only compiler problem in code changed in 1.42 @@ -7609,73 +8428,48 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user fixes to stbi__cleanup_jpeg path added STBI_ASSERT to avoid requiring assert.h 1.41 (2014-06-25) - fix search&replace from 1.36 that messed up comments/error messages - 1.40 (2014-06-22) - fix gcc struct-initialization warning - 1.39 (2014-06-15) - fix to TGA optimization when req_comp != number of components in TGA; - fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) - add support for BMP version 5 (more ignored fields) - 1.38 (2014-06-06) - suppress MSVC warnings on integer casts truncating values - fix accidental rename of 'skip' field of I/O - 1.37 (2014-06-04) - remove duplicate typedef - 1.36 (2014-06-03) - convert to header file single-file library - if de-iphone isn't set, load iphone images color-swapped instead of returning NULL - 1.35 (2014-05-27) - various warnings - fix broken STBI_SIMD path - fix bug where stbi_load_from_file no longer left file pointer in correct place - fix broken non-easy path for 32-bit BMP (possibly never used) - TGA optimization by Arseny Kapoulkine - 1.34 (unknown) - use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case - 1.33 (2011-07-14) - make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements - 1.32 (2011-07-13) - support for "info" function for all supported filetypes (SpartanJ) - 1.31 (2011-06-20) - a few more leak fixes, bug in PNG handling (SpartanJ) - 1.30 (2011-06-11) - added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + fix search&replace from 1.36 that messed up comments/error + messages 1.40 (2014-06-22) fix gcc struct-initialization warning 1.39 + (2014-06-15) fix to TGA optimization when req_comp != number of components in + TGA; fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my + test suite) add support for BMP version 5 (more ignored fields) 1.38 + (2014-06-06) suppress MSVC warnings on integer casts truncating values fix + accidental rename of 'skip' field of I/O 1.37 (2014-06-04) remove duplicate + typedef 1.36 (2014-06-03) convert to header file single-file library if + de-iphone isn't set, load iphone images color-swapped instead of returning + NULL 1.35 (2014-05-27) various warnings fix broken STBI_SIMD path fix bug + where stbi_load_from_file no longer left file pointer in correct place fix + broken non-easy path for 32-bit BMP (possibly never used) TGA optimization by + Arseny Kapoulkine 1.34 (unknown) use STBI_NOTUSED in + stbi__resample_row_generic(), fix one more leak in tga failure case 1.33 + (2011-07-14) make stbi_is_hdr work in STBI_NO_HDR (as specified), minor + compiler-friendly improvements 1.32 (2011-07-13) support for "info" function + for all supported filetypes (SpartanJ) 1.31 (2011-06-20) a few more leak + fixes, bug in PNG handling (SpartanJ) 1.30 (2011-06-11) added ability to + load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway - error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) - fix inefficiency in decoding 32-bit BMP (David Woo) - 1.29 (2010-08-16) - various warning fixes from Aurelien Pocheville - 1.28 (2010-08-01) - fix bug in GIF palette transparency (SpartanJ) - 1.27 (2010-08-01) - cast-to-stbi_uc to fix warnings - 1.26 (2010-07-24) - fix bug in file buffering for PNG reported by SpartanJ - 1.25 (2010-07-17) - refix trans_data warning (Won Chun) - 1.24 (2010-07-12) - perf improvements reading from files on platforms with lock-heavy fgetc() - minor perf improvements for jpeg - deprecated type-specific functions so we'll get feedback if they're needed - attempt to fix trans_data warning (Won Chun) - 1.23 fixed bug in iPhone support - 1.22 (2010-07-10) - removed image *writing* support - stbi_info support from Jetro Lauha - GIF support from Jean-Marc Lienher + removed support for installable file formats (stbi_loader) -- + would have been broken for IO callbacks anyway error cases in bmp and tga + give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in + decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from + Aurelien Pocheville 1.28 (2010-08-01) fix bug in GIF palette transparency + (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings 1.26 + (2010-07-24) fix bug in file buffering for PNG reported by SpartanJ 1.25 + (2010-07-17) refix trans_data warning (Won Chun) 1.24 (2010-07-12) perf + improvements reading from files on platforms with lock-heavy fgetc() minor + perf improvements for jpeg deprecated type-specific functions so we'll get + feedback if they're needed attempt to fix trans_data warning (Won Chun) 1.23 + fixed bug in iPhone support 1.22 (2010-07-10) removed image *writing* + support stbi_info support from Jetro Lauha GIF support from Jean-Marc Lienher iPhone PNG-extensions from James Brown - warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) - 1.21 fix use of 'stbi_uc' in header (reported by jon blow) - 1.20 added support for Softimage PIC, by Tom Seddon - 1.19 bug in interlaced PNG corruption check (found by ryg) - 1.18 (2008-08-02) - fix a threading bug (local mutable static) - 1.17 support interlaced PNG - 1.16 major bugfix - stbi__convert_format converted one too many pixels - 1.15 initialize some fields for thread safety - 1.14 fix threadsafe conversion bug - header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. + Janez (U+017D)emva) 1.21 fix use of 'stbi_uc' in header (reported by jon + blow) 1.20 added support for Softimage PIC, by Tom Seddon 1.19 bug in + interlaced PNG corruption check (found by ryg) 1.18 (2008-08-02) fix a + threading bug (local mutable static) 1.17 support interlaced PNG 1.16 + major bugfix - stbi__convert_format converted one too many pixels 1.15 + initialize some fields for thread safety 1.14 fix threadsafe conversion + bug header-file-only version (#define STBI_HEADER_FILE_ONLY before including) 1.13 threadsafe 1.12 const qualifiers in the API 1.11 Support installable IDCT, colorspace conversion routines @@ -7685,15 +8479,14 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz 1.07 attempt to fix C++ warning/errors again 1.06 attempt to fix C++ warning/errors again - 1.05 fix TGA loading to return correct *comp and use good luminance calc - 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free - 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR - 1.02 support for (subset of) HDR files, float interface for preferred access to them - 1.01 fix bug: possible bug in handling right-side up bmps... not sure - fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all - 1.00 interface to zlib that skips zlib header - 0.99 correct handling of alpha in palette - 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 1.05 fix TGA loading to return correct *comp and use good luminance + calc 1.04 default float alpha is 1, not 255; use 'void *' for + stbi_image_free 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR 1.02 support + for (subset of) HDR files, float interface for preferred access to them 1.01 + fix bug: possible bug in handling right-side up bmps... not sure fix bug: the + stbi__bmp_load() and stbi__tga_load() functions didn't work at all 1.00 + interface to zlib that skips zlib header 0.99 correct handling of alpha in + palette 0.98 TGA loader by lonesock; dynamically add loaders (untested) 0.97 jpeg errors on too large a file; also catch another malloc failure 0.96 fix detection of invalid v value - particleman@mollyrocket forum 0.95 during header scan, seek to markers in case of padding @@ -7706,8 +8499,8 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 0.60 fix compiling as c++ 0.59 fix warnings: merge Dave Moore's -Wall fixes 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian - 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available - 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but + less than 16 available 0.56 fix bug: zlib uncompressed mode len vs. nlen 0.55 fix bug: restart_interval not initialized to 0 0.54 allow NULL for 'int *comp' 0.53 fix bug in png 3->4; speedup png decoding @@ -7718,7 +8511,6 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user first released version */ - /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. diff --git a/sample_programs/ml_sample_programs/vision_models/dave_keras/compile.sh b/sample_programs/ml_sample_programs/vision_models/dave_keras/compile.sh index 4532f5b6..1c020d63 100755 --- a/sample_programs/ml_sample_programs/vision_models/dave_keras/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/dave_keras/compile.sh @@ -5,7 +5,7 @@ printf "\n[Compile Script]: Convert TF model to LLVM IR\n" python3 -m tf2onnx.convert --saved-model dave2-keras.tf --output model.onnx #python3 ../../tools/ExtendONNXModel.py model.onnx extendedmodel.onnx --layers="conv;max_pool" > expected_op_seq.txt python3 ../../../../tools/ExtendONNXModel.py --model_path ./model.onnx --output_model_path ./extendedmodel.onnx > expected_op_seq.txt -onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp +onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp mlir-translate -mlir-to-llvmir extendedmodel.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/dave_keras/datagen.py b/sample_programs/ml_sample_programs/vision_models/dave_keras/datagen.py index dbe4ee03..fddf5e0a 100644 --- a/sample_programs/ml_sample_programs/vision_models/dave_keras/datagen.py +++ b/sample_programs/ml_sample_programs/vision_models/dave_keras/datagen.py @@ -5,15 +5,18 @@ @author: berk """ + import numpy as np import os from tensorflow.python.keras.utils.data_utils import Sequence import cv2 + class DataGenerator(Sequence): - 'Generates data for Keras' - def __init__(self, list_IDs, labels, batch_size=32, dim=(66,200,3), shuffle=True): - 'Initialization' + "Generates data for Keras" + + def __init__(self, list_IDs, labels, batch_size=32, dim=(66, 200, 3), shuffle=True): + "Initialization" self.dim = dim self.batch_size = batch_size self.labels = labels @@ -22,13 +25,13 @@ def __init__(self, list_IDs, labels, batch_size=32, dim=(66,200,3), shuffle=True self.on_epoch_end() def __len__(self): - 'Denotes the number of batches per epoch' + "Denotes the number of batches per epoch" return int(np.floor(len(self.list_IDs) / self.batch_size)) def __getitem__(self, index): - 'Generate one batch of data' + "Generate one batch of data" # Generate indexes of the batch - indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] + indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size] # Find list of IDs list_IDs_temp = [self.list_IDs[k] for k in indexes] @@ -39,13 +42,13 @@ def __getitem__(self, index): return X, y def on_epoch_end(self): - 'Updates indexes after each epoch' + "Updates indexes after each epoch" self.indexes = np.arange(len(self.list_IDs)) if self.shuffle == True: np.random.shuffle(self.indexes) def __data_generation(self, list_IDs_temp): - 'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels) + "Generates data containing batch_size samples" # X : (n_samples, *dim, n_channels) # Initialization X = np.empty((self.batch_size, *self.dim)) y = np.empty(self.batch_size) @@ -53,11 +56,13 @@ def __data_generation(self, list_IDs_temp): # Generate data for i, ID in enumerate(list_IDs_temp): # Store sample - image = cv2.imread(os.getcwd()+"/data/"+str(ID)+".jpg") #read images from disk - image=cv2.resize(image[-150:], (200,66))/255 + image = cv2.imread( + os.getcwd() + "/data/" + str(ID) + ".jpg" + ) # read images from disk + image = cv2.resize(image[-150:], (200, 66)) / 255 X[i,] = image - + # Store class y[i] = float(self.labels[ID]) - return X, y \ No newline at end of file + return X, y diff --git a/sample_programs/ml_sample_programs/vision_models/dave_keras/image.c b/sample_programs/ml_sample_programs/vision_models/dave_keras/image.c index 7a52a815..f1279622 100644 --- a/sample_programs/ml_sample_programs/vision_models/dave_keras/image.c +++ b/sample_programs/ml_sample_programs/vision_models/dave_keras/image.c @@ -157,7 +157,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/vision_models/dave_keras/run.py b/sample_programs/ml_sample_programs/vision_models/dave_keras/run.py index a81b2c21..b13550ff 100644 --- a/sample_programs/ml_sample_programs/vision_models/dave_keras/run.py +++ b/sample_programs/ml_sample_programs/vision_models/dave_keras/run.py @@ -5,6 +5,7 @@ @author: berk """ + from tensorflow.keras.preprocessing.image import img_to_array import scipy.misc import os @@ -15,36 +16,35 @@ from pdb import set_trace input_shape = (66, 200, 3) -angle=[] -smooth_angle=0 -test_ids=[] - +angle = [] +smooth_angle = 0 +test_ids = [] -f= open("data.txt") #read steering angles from disk and preprocess +f = open("data.txt") # read steering angles from disk and preprocess data = f.read() data = data.split() -for i in data: #if the node end with ".jpg" ignore it. It's for collecting angles - if i[-1]=='g': +for i in data: # if the node end with ".jpg" ignore it. It's for collecting angles + if i[-1] == "g": pass else: - angle.append(float(i) * scipy.pi / 180) #convert rad. - -model = load_model("model.h5") #import our model -model.save('dave2-keras.tf/') + angle.append(float(i) * scipy.pi / 180) # convert rad. + +model = load_model("model.h5") # import our model +model.save("dave2-keras.tf/") -#sahin_direksiyon = cv2.imread("sahin_direksiyon_simiti.png") #read steering image +# sahin_direksiyon = cv2.imread("sahin_direksiyon_simiti.png") #read steering image -#set_trace() +# set_trace() -#image=cv2.resize(sahin_direksiyon[-150:], (200,66)) -#image = image / 255 +# image=cv2.resize(sahin_direksiyon[-150:], (200,66)) +# image = image / 255 -#image = img_to_array(image)/255 -#result = -model.predict(image[None])*180.0/scipy.pi #make a prediction -#print("Predicted Angle= " + str(-result)) +# image = img_to_array(image)/255 +# result = -model.predict(image[None])*180.0/scipy.pi #make a prediction +# print("Predicted Angle= " + str(-result)) -''' +""" test_paths = list(paths.list_images(os.getcwd()+"/test")) #get test images ids (names) for i in test_paths: @@ -77,12 +77,7 @@ time.sleep(0.02) if cv2.waitKey(1) & 0xFF == ord("q"): break -''' +""" cv2.destroyAllWindows() - - - - - diff --git a/sample_programs/ml_sample_programs/vision_models/dave_keras/stb_image.h b/sample_programs/ml_sample_programs/vision_models/dave_keras/stb_image.h index accef483..5b891039 100644 --- a/sample_programs/ml_sample_programs/vision_models/dave_keras/stb_image.h +++ b/sample_programs/ml_sample_programs/vision_models/dave_keras/stb_image.h @@ -3,7 +3,8 @@ Do this: #define STB_IMAGE_IMPLEMENTATION - before you include this file in *one* C or C++ file to create the implementation. + before you include this file in *one* C or C++ file to create the +implementation. // i.e. it should look like this: #include ... @@ -13,15 +14,16 @@ #include "stb_image.h" You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. - And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using +malloc,realloc,free QUICK NOTES: Primarily of interest to game developers and other people who can avoid problematic images and only need the trivial interface - JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) - PNG 1/2/4/8/16-bit-per-channel + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as +stock IJG lib) PNG 1/2/4/8/16-bit-per-channel TGA (not sure what subset, if a subset) BMP non-1bpp, non-RLE @@ -50,25 +52,22 @@ RECENT REVISION HISTORY: 2.26 (2020-07-13) many minor fixes 2.25 (2020-02-02) fix warnings - 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically - 2.23 (2019-08-11) fix clang static analysis warning - 2.22 (2019-03-04) gif fixes, fix warnings - 2.21 (2019-02-25) fix typo in comment - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and +flip_vertically 2.23 (2019-08-11) fix clang static analysis warning 2.22 +(2019-03-04) gif fixes, fix warnings 2.21 (2019-02-25) fix typo in comment 2.20 +(2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix warnings 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings - 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes - 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 - RGB-format JPEG; remove white matting in PSD; - allocate large structures on the stack; - correct channel count for PNG & BMP - 2.10 (2016-01-22) avoid warning introduced in 2.09 - 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; +bugfixes 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE +detection on GCC 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for +Imagenet JPGs 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; +fixes 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 +(2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 RGB-format JPEG; remove +white matting in PSD; allocate large structures on the stack; correct channel +count for PNG & BMP 2.10 (2016-01-22) avoid warning introduced in 2.09 2.09 +(2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED See end of file for full revision history. @@ -86,38 +85,37 @@ RECENT REVISION HISTORY: github:urraka (animated gif) Junggon Kim (PNM comments) Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) socks-the-fox (16-bit PNG) - Jeremy Sawicki (handle all ImageNet JPGs) - Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Jeremy Sawicki (handle all ImageNet +JPGs) Optimizations & bugfixes Mikhail Morozov (1-bit BMP) Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) Arseny Kapoulkine John-Mark Allen Carmelo J Fdez-Aguera Bug & warning fixes - Marc LeBlanc David Woo Guillaume George Martins Mozeiko - Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski - Phil Jordan Dave Moore Roy Eltham - Hayaki Saito Nathan Reed Won Chun - Luke Graham Johan Duparc Nick Verigakis the Horde3D community - Thomas Ruf Ronny Chevalier github:rlyeh - Janez Zemva John Bartholomew Michal Cichon github:romigrou - Jonathan Blow Ken Hamada Tero Hanninen github:svdijk - Laurent Gomila Cort Stratton github:snagar - Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex - Cass Everitt Ryamond Barbiero github:grim210 - Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw - Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus - Josh Tobin Matthew Gregan github:poppolopoppo - Julian Raschke Gregory Mullen Christian Floisand github:darealshinji - Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 - Brad Weinberger Matvey Cherevko [reserved] - Luca Sas Alexander Veselov Zack Middleton [reserved] + Marc LeBlanc David Woo Guillaume George Martins +Mozeiko Christpher Lloyd Jerry Jansson Joseph Thomson Blazej +Dariusz Roszkowski Phil Jordan Dave Moore Roy +Eltham Hayaki Saito Nathan Reed Won Chun Luke Graham Johan +Duparc Nick Verigakis the Horde3D community Thomas Ruf Ronny +Chevalier github:rlyeh Janez Zemva John +Bartholomew Michal Cichon github:romigrou Jonathan Blow Ken +Hamada Tero Hanninen github:svdijk Laurent Gomila Cort +Stratton github:snagar Aruelien Pocheville Sergio Gonzalez Thibault +Reuille github:Zelex Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Matthew Gregan +github:poppolopoppo Julian Raschke Gregory Mullen Christian +Floisand github:darealshinji Baldur Karlsson Kevin Schmidt JR +Smith github:Michaelangel007 Brad Weinberger Matvey Cherevko +[reserved] Luca Sas Alexander Veselov Zack Middleton [reserved] Ryan C. Gordon [reserved] [reserved] DO NOT ADD YOUR NAME HERE - To add your name to the credits, pick a random blank space in the middle and fill it. - 80% of merge conflicts on stb PRs are due to people adding their name at the end - of the credits. + To add your name to the credits, pick a random blank space in the middle and +fill it. 80% of merge conflicts on stb PRs are due to people adding their name +at the end of the credits. */ #ifndef STBI_INCLUDE_STB_IMAGE_H @@ -136,14 +134,15 @@ RECENT REVISION HISTORY: // // ... process data if not NULL ... // // ... x = width, y = height, n = # 8-bit components per pixel ... // // ... replace '0' with '1'..'4' to force that many components per pixel -// // ... but 'n' will always be the number that it would have been if you said 0 -// stbi_image_free(data) +// // ... but 'n' will always be the number that it would have been if you +// said 0 stbi_image_free(data) // // Standard parameters: // int *x -- outputs image width in pixels // int *y -- outputs image height in pixels // int *channels_in_file -- outputs # of image components in image file -// int desired_channels -- if non-zero, # of image components requested in result +// int desired_channels -- if non-zero, # of image components requested in +// result // // The return value from an image loader is an 'unsigned char *' which points // to the pixel data, or NULL on an allocation failure or if the image is @@ -171,8 +170,8 @@ RECENT REVISION HISTORY: // and *x, *y, *channels_in_file will be unchanged. The function // stbi_failure_reason() can be queried for an extremely brief, end-user // unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS -// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly -// more user-friendly ones. +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get +// slightly more user-friendly ones. // // Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. // @@ -196,11 +195,12 @@ RECENT REVISION HISTORY: // 2. easy to maintain // 3. good performance // -// Sometimes I let "good performance" creep up in priority over "easy to maintain", -// and for best performance I may provide less-easy-to-use APIs that give higher -// performance, in addition to the easy-to-use ones. Nevertheless, it's important -// to keep in mind that from the standpoint of you, a client of this library, -// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// Sometimes I let "good performance" creep up in priority over "easy to +// maintain", and for best performance I may provide less-easy-to-use APIs that +// give higher performance, in addition to the easy-to-use ones. Nevertheless, +// it's important to keep in mind that from the standpoint of you, a client of +// this library, all you care about is #1 and #3, and stb libraries DO NOT +// emphasize #3 above all. // // Some secondary priorities arise directly from the first two, some of which // provide more explicit reasons why performance can't be emphasized. @@ -219,7 +219,8 @@ RECENT REVISION HISTORY: // overhead. // // The three functions you must define are "read" (reads some bytes of data), -// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the +// end). // // =========================================================================== // @@ -247,10 +248,11 @@ RECENT REVISION HISTORY: // HDR image support (disable by defining STBI_NO_HDR) // // stb_image supports loading HDR images in general, and currently the Radiance -// .HDR file format specifically. You can still load any file through the existing -// interface; if you attempt to load an HDR file, it will be automatically remapped -// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; -// both of these constants can be reconfigured through this interface: +// .HDR file format specifically. You can still load any file through the +// existing interface; if you attempt to load an HDR file, it will be +// automatically remapped to LDR, assuming gamma 2.2 and an arbitrary scale +// factor defaulting to 1; both of these constants can be reconfigured through +// this interface: // // stbi_hdr_to_ldr_gamma(2.2f); // stbi_hdr_to_ldr_scale(1.0f); @@ -342,14 +344,13 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum -{ - STBI_default = 0, // only used for desired_channels +enum { + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -377,11 +378,13 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct -{ - int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof) (void *user); // returns nonzero if we are at end of file/data +typedef struct { + int (*read)(void *user, char *data, + int size); // fill 'data' with 'size' bytes. return number of + // bytes actually read + void (*skip)(void *user, int n); // skip the next 'n' bytes, or 'unget' the + // last -n bytes if negative + int (*eof)(void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -389,21 +392,33 @@ typedef struct // 8-bits-per-channel interface // -STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); -// for stbi_load_from_file, file pointer is left pointing immediately after image +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after +// image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input); #endif //////////////////////////////////// @@ -411,12 +426,20 @@ STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wch // 16-bits-per-channel interface // -STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); #endif //////////////////////////////////// @@ -424,83 +447,102 @@ STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_i // float-per-channel interface // #ifndef STBI_NO_LINEAR - STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); - #ifndef STBI_NO_STDIO - STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); - #endif +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +#endif #endif #ifndef STBI_NO_HDR - STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); - STBIDEF void stbi_hdr_to_ldr_scale(float scale); +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); +STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR - STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); - STBIDEF void stbi_ldr_to_hdr_scale(float scale); +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); +STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename); -STBIDEF int stbi_is_hdr_from_file(FILE *f); +STBIDEF int stbi_is_hdr(char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO - // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char *stbi_failure_reason (void); +STBIDEF const char *stbi_failure_reason(void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free (void *retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, + int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, + void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit (char const *filename); -STBIDEF int stbi_is_16_bit_from_file(FILE *f); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit(char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif - - // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); // indicate whether we should process iphone images back to canonical format, // or just pass them through "as-is" STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); -// flip the image vertically, so the first pixel in the output array is the bottom left +// flip the image vertically, so the first pixel in the output array is the +// bottom left STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); -// as above, but only applies to images loaded on the thread that calls the function -// this function is only available if your compiler supports thread-local variables; -// calling it will fail to link if your compiler doesn't -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); +// as above, but only applies to images loaded on the thread that calls the +// function this function is only available if your compiler supports +// thread-local variables; calling it will fail to link if your compiler doesn't +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); // ZLIB client - used by PNG, available for other purposes -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header); STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); - -STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, + int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); #ifdef __cplusplus } @@ -513,52 +555,53 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ - || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ - || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ - || defined(STBI_ONLY_ZLIB) - #ifndef STBI_ONLY_JPEG - #define STBI_NO_JPEG - #endif - #ifndef STBI_ONLY_PNG - #define STBI_NO_PNG - #endif - #ifndef STBI_ONLY_BMP - #define STBI_NO_BMP - #endif - #ifndef STBI_ONLY_PSD - #define STBI_NO_PSD - #endif - #ifndef STBI_ONLY_TGA - #define STBI_NO_TGA - #endif - #ifndef STBI_ONLY_GIF - #define STBI_NO_GIF - #endif - #ifndef STBI_ONLY_HDR - #define STBI_NO_HDR - #endif - #ifndef STBI_ONLY_PIC - #define STBI_NO_PIC - #endif - #ifndef STBI_ONLY_PNM - #define STBI_NO_PNM - #endif -#endif - -#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) -#define STBI_NO_ZLIB +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || \ + defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || \ + defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ + defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) +#ifndef STBI_ONLY_JPEG +#define STBI_NO_JPEG +#endif +#ifndef STBI_ONLY_PNG +#define STBI_NO_PNG +#endif +#ifndef STBI_ONLY_BMP +#define STBI_NO_BMP +#endif +#ifndef STBI_ONLY_PSD +#define STBI_NO_PSD +#endif +#ifndef STBI_ONLY_TGA +#define STBI_NO_TGA +#endif +#ifndef STBI_ONLY_GIF +#define STBI_NO_GIF +#endif +#ifndef STBI_ONLY_HDR +#define STBI_NO_HDR +#endif +#ifndef STBI_ONLY_PIC +#define STBI_NO_PIC +#endif +#ifndef STBI_ONLY_PNM +#define STBI_NO_PNM +#endif #endif +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && \ + !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif +#include #include #include // ptrdiff_t on osx #include #include -#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -576,55 +619,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #define STBI_EXTERN extern #endif - #ifndef _MSC_VER - #ifdef __cplusplus - #define stbi_inline inline - #else - #define stbi_inline - #endif +#ifdef __cplusplus +#define stbi_inline inline +#else +#define stbi_inline +#endif #else - #define stbi_inline __forceinline +#define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS - #if defined(__cplusplus) && __cplusplus >= 201103L - #define STBI_THREAD_LOCAL thread_local - #elif defined(__GNUC__) && __GNUC__ < 5 - #define STBI_THREAD_LOCAL __thread - #elif defined(_MSC_VER) - #define STBI_THREAD_LOCAL __declspec(thread) - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) - #define STBI_THREAD_LOCAL _Thread_local - #endif - - #ifndef STBI_THREAD_LOCAL - #if defined(__GNUC__) - #define STBI_THREAD_LOCAL __thread - #endif - #endif +#if defined(__cplusplus) && __cplusplus >= 201103L +#define STBI_THREAD_LOCAL thread_local +#elif defined(__GNUC__) && __GNUC__ < 5 +#define STBI_THREAD_LOCAL __thread +#elif defined(_MSC_VER) +#define STBI_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && \ + !defined(__STDC_NO_THREADS__) +#define STBI_THREAD_LOCAL _Thread_local +#endif + +#ifndef STBI_THREAD_LOCAL +#if defined(__GNUC__) +#define STBI_THREAD_LOCAL __thread +#endif +#endif #endif #ifdef _MSC_VER typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -632,27 +675,30 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL - #define stbi_lrot(x,y) _lrotl(x,y) +#define stbi_lrot(x, y) _lrotl(x, y) #else - #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (32 - (y)))) +#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (32 - (y)))) #endif -#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +#if defined(STBI_MALLOC) && defined(STBI_FREE) && \ + (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) // ok -#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && \ + !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) // ok #else -#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." +#error \ + "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p,newsz) realloc(p,newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p, newsz) realloc(p, newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) +#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) #endif // x86/x64 detection @@ -662,7 +708,8 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI__X86_TARGET #endif -#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && \ + !defined(STBI_NO_SIMD) // gcc doesn't support sse2 intrinsics unless you compile with -msse2, // which in turn means it gets to use SSE2 everywhere. This is unfortunate, // but previous attempts to provide the SSE2 functions with runtime @@ -673,8 +720,10 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI_NO_SIMD #endif -#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) -// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && \ + !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid +// STBI__X64_TARGET // // 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the // Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. @@ -684,44 +733,43 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; // See https://github.com/nothings/stb/issues/81 for more information. // // So default to no SSE2 on 32-bit MinGW. If you've read this far and added -// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +// -mstackrealign to your build settings, feel free to #define +// STBI_MINGW_ENABLE_SSE2. #define STBI_NO_SIMD #endif -#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#if !defined(STBI_NO_SIMD) && \ + (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) #define STBI_SSE2 #include #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) -{ - int info[4]; - __cpuid(info,1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) { + int info[4]; + __cpuid(info, 1); + return info[3]; } #else -static int stbi__cpuid3(void) -{ - int res; - __asm { +static int stbi__cpuid3(void) { + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) { + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -729,12 +777,11 @@ static int stbi__sse2_available(void) #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) { + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -766,188 +813,182 @@ static int stbi__sse2_available(void) // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct -{ - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct { + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void *io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; - static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) -{ - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) { + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) -{ - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, + void *user) { + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void *user, char *data, int size) -{ - return (int) fread(data,1,size,(FILE*) user); +static int stbi__stdio_read(void *user, char *data, int size) { + return (int)fread(data, 1, size, (FILE *)user); } -static void stbi__stdio_skip(void *user, int n) -{ - int ch; - fseek((FILE*) user, n, SEEK_CUR); - ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ - } +static void stbi__stdio_skip(void *user, int n) { + int ch; + fseek((FILE *)user, n, SEEK_CUR); + ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ + } } -static int stbi__stdio_eof(void *user) -{ - return feof((FILE*) user) || ferror((FILE *) user); +static int stbi__stdio_eof(void *user) { + return feof((FILE *)user) || ferror((FILE *)user); } -static stbi_io_callbacks stbi__stdio_callbacks = -{ - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static stbi_io_callbacks stbi__stdio_callbacks = { + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context *s, FILE *f) -{ - stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +static void stbi__start_file(stbi__context *s, FILE *f) { + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } -//static void stop_file(stbi__context *s) { } +// static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context *s) -{ - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 + // bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum -{ - STBI_ORDER_RGB, - STBI_ORDER_BGR -}; +enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; -typedef struct -{ - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct { + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context *s); -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context *s); -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__png_is16(stbi__context *s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context *s); -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context *s); -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s); -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__psd_is16(stbi__context *s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context *s); -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context *s); -static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context *s); -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s); -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); #endif static #ifdef STBI_THREAD_LOCAL -STBI_THREAD_LOCAL + STBI_THREAD_LOCAL #endif -const char *stbi__g_failure_reason; + const char *stbi__g_failure_reason; -STBIDEF const char *stbi_failure_reason(void) -{ - return stbi__g_failure_reason; +STBIDEF const char *stbi_failure_reason(void) { + return stbi__g_failure_reason; } #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char *str) -{ - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) { + stbi__g_failure_reason = str; + return 0; } #endif -static void *stbi__malloc(size_t size) -{ - return STBI_MALLOC(size); +static void *stbi__malloc(size_t size) { + return STBI_MALLOC(size); } // stb_image uses ints pervasively, including for offset calculations. @@ -962,70 +1003,72 @@ static void *stbi__malloc(size_t size) // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) -{ - if (b < 0) return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) { + if (b < 0) + return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) -{ - if (a < 0 || b < 0) return 0; - if (b == 0) return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX/b; +static int stbi__mul2sizes_valid(int a, int b) { + if (a < 0 || b < 0) + return 0; + if (b == 0) + return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; } -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__addsizes_valid(a*b*c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__addsizes_valid(a * b * c, add); } -// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't +// overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__mul2sizes_valid(a * b * c, d) && + stbi__addsizes_valid(a * b * c * d, add); } #endif -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void *stbi__malloc_mad2(int a, int b, int add) -{ - if (!stbi__mad2sizes_valid(a, b, add)) return NULL; - return stbi__malloc(a*b + add); +static void *stbi__malloc_mad2(int a, int b, int add) { + if (!stbi__mad2sizes_valid(a, b, add)) + return NULL; + return stbi__malloc(a * b + add); } #endif -static void *stbi__malloc_mad3(int a, int b, int c, int add) -{ - if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; - return stbi__malloc(a*b*c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) { + if (!stbi__mad3sizes_valid(a, b, c, add)) + return NULL; + return stbi__malloc(a * b * c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) -{ - if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; - return stbi__malloc(a*b*c*d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) { + if (!stbi__mad4sizes_valid(a, b, c, d, add)) + return NULL; + return stbi__malloc(a * b * c * d + add); } #endif @@ -1034,417 +1077,459 @@ static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS - #define stbi__err(x,y) 0 +#define stbi__err(x, y) 0 #elif defined(STBI_FAILURE_USERMSG) - #define stbi__err(x,y) stbi__err(y) +#define stbi__err(x, y) stbi__err(y) #else - #define stbi__err(x,y) stbi__err(x) +#define stbi__err(x, y) stbi__err(x) #endif -#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) -#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpuc(x, y) \ + ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -STBIDEF void stbi_image_free(void *retval_from_stbi_load) -{ - STBI_FREE(retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load) { + STBI_FREE(retval_from_stbi_load); } #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else -static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, + stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ - ? stbi__vertically_flip_on_load_local \ - : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load \ + (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; - - #ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNG - if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_GIF - if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PSD - if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); - #else - STBI_NOTUSED(bpc); - #endif - #ifndef STBI_NO_PIC - if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); - #endif - - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } - #endif - - #ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s,x,y,comp,req_comp, ri); - #endif - - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); -} - -static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi_uc *reduced; - - reduced = (stbi_uc *) stbi__malloc(img_len); - if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - - STBI_FREE(orig); - return reduced; -} - -static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi__uint16 *enlarged; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = + 8; // default is 8 so most paths don't have to be changed + ri->channel_order = + STBI_ORDER_RGB; // all current input & output are this, but this is here + // so we can add BGR order + ri->num_channels = 0; - enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); - if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); +#ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) + return stbi__jpeg_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNG + if (stbi__png_test(s)) + return stbi__png_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) + return stbi__bmp_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_GIF + if (stbi__gif_test(s)) + return stbi__gif_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PSD + if (stbi__psd_test(s)) + return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); +#else + STBI_NOTUSED(bpc); +#endif +#ifndef STBI_NO_PIC + if (stbi__pic_test(s)) + return stbi__pic_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) + return stbi__pnm_load(s, x, y, comp, req_comp, ri); +#endif - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } +#endif - STBI_FREE(orig); - return enlarged; -} +#ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s, x, y, comp, req_comp, ri); +#endif -static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) -{ - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc *bytes = (stbi_uc *)image; - - for (row = 0; row < (h>>1); row++) { - stbi_uc *row0 = bytes + row*bytes_per_row; - stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } + return stbi__errpuc("unknown image type", + "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *)stbi__malloc(img_len); + if (reduced == NULL) + return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = + (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient + // approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); + if (enlarged == NULL) + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + + orig[i]); // replicate to high and low byte, + // maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, + int bytes_per_pixel) { + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h >> 1); row++) { + stbi_uc *row0 = bytes + row * bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1) * bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = + (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) -{ - int slice; - int slice_size = w * h * bytes_per_pixel; +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, + int bytes_per_pixel) { + int slice; + int slice_size = w * h * bytes_per_pixel; - stbi_uc *bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *) result; + return (unsigned char *)result; } -static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to + // keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *) result; + return (stbi__uint16 *)result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) -{ - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, + int req_comp) { + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar( + unsigned int cp, unsigned long flags, const char *str, int cbmb, + wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte( + unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, + char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) -{ - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input) { + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, + (int)bufferlen, NULL, NULL); } #endif -static FILE *stbi__fopen(char const *filename, char const *mode) -{ - FILE *f; +static FILE *stbi__fopen(char const *filename, char const *mode) { + FILE *f; #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, + sizeof(wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) - return 0; + if (0 == + MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) + return 0; #if _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f=0; + if (0 != fopen_s(&f, filename, mode)) + f = 0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; -} - - -STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - unsigned char *result; - if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; -} - -STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__uint16 *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - stbi__uint16 *result; - if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f,x,y,comp,req_comp); - fclose(f); - return result; -} - - -#endif //!STBI_NO_STDIO - -STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); -} - -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + return f; +} + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) + return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) + return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +#endif //! STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_mem(&s,buffer,len); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_mem(&s, buffer, len); - result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); - } + result = + (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices(result, *x, *y, *z, *comp); + } - return result; + return result; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *data; - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data,x,y,comp,req_comp); - return hdr_data; - } - #endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); -} - -STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__loadf_main(&s,x,y,comp,req_comp); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp) { + unsigned char *data; +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data, x, y, comp, req_comp); + return hdr_data; + } +#endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", + "Image not of any known type, or corrupt"); } -STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__loadf_main(&s,x,y,comp,req_comp); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -#ifndef STBI_NO_STDIO -STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - float *result; - FILE *f = stbi__fopen(filename, "rb"); - if (!f) return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_file(&s,f); - return stbi__loadf_main(&s,x,y,comp,req_comp); +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, + int req_comp) { + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) + return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_file(&s, f); + return stbi__loadf_main(&s, x, y, comp, req_comp); } #endif // !STBI_NO_STDIO @@ -1454,221 +1539,222 @@ STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_ // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; +#endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result=0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result = 0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE *f) -{ - #ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s,f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; - #else - STBI_NOTUSED(f); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) { +#ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s, f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; +#else + STBI_NOTUSED(f); + return 0; +#endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; +#endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; +static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { + stbi__l2h_gamma = gamma; +} +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { + stbi__l2h_scale = scale; +} #endif -static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; - -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } +static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { + stbi__h2l_gamma_i = 1 / gamma; +} +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { + stbi__h2l_scale_i = 1 / scale; +} ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum -{ - STBI__SCAN_load=0, - STBI__SCAN_type, - STBI__SCAN_header -}; - -static void stbi__refill_buffer(stbi__context *s) -{ - int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); - s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start+1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } -} - -stbi_inline static stbi_uc stbi__get8(stbi__context *s) -{ - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; -} - -#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; + +static void stbi__refill_buffer(stbi__context *s) { + int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context *s) -{ - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) + return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) + return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context *s, int n) -{ - if (n == 0) return; // already there! - if (n < 0) { +static void stbi__skip(stbi__context *s, int n) { + if (n == 0) + return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); return; - } - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; + } + } + s->img_buffer += n; } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && \ + defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) -{ - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) { + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; - memcpy(buffer, s->img_buffer, blen); + memcpy(buffer, s->img_buffer, blen); - count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); - res = (count == (n-blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } + count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); + res = (count == (n - blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } - if (s->img_buffer+n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + if (s->img_buffer + n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context *s) -{ - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) { + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context *s) -{ - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) { + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context *s) -{ - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) { + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context *s) -{ - stbi__uint32 z = stbi__get16le(s); - return z + (stbi__get16le(s) << 16); +static stbi__uint32 stbi__get32le(stbi__context *s) { + stbi__uint32 z = stbi__get16le(s); + return z + (stbi__get16le(s) << 16); } #endif -#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) \ + ((stbi_uc)((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1682,169 +1768,301 @@ static stbi__uint32 stbi__get32le(stbi__context *s) // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) -{ - return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi_uc stbi__compute_y(int r, int g, int b) { + return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - unsigned char *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - unsigned char *src = data + j * x * img_n ; - unsigned char *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + unsigned char *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + unsigned char *src = data + j * x * img_n; + unsigned char *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 255; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 255; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 255; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = 255; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) -{ - return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { + return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - stbi__uint16 *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - stbi__uint16 *src = data + j * x * img_n ; - stbi__uint16 *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + stbi__uint16 *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + stbi__uint16 *src = data + j * x * img_n; + stbi__uint16 *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 0xffff; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 0xffff; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 0xffff; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = 0xffff; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return (stbi__uint16 *)stbi__errpuc("unsupported", + "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) -{ - int i,k,n; - float *output; - if (!data) return NULL; - output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i=0; i < x*y; ++i) { - output[i*comp + n] = data[i*comp + n]/255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) { + int i, k, n; + float *output; + if (!data) + return NULL; + output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpf("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + output[i * comp + k] = + (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * + stbi__l2h_scale); + } + } + if (n < comp) { + for (i = 0; i < x * y; ++i) { + output[i * comp + n] = data[i * comp + n] / 255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int) (x)) -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) -{ - int i,k,n; - stbi_uc *output; - if (!data) return NULL; - output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - if (k < comp) { - float z = data[i*comp+k] * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int)(x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) { + int i, k, n; + stbi_uc *output; + if (!data) + return NULL; + output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, + stbi__h2l_gamma_i) * + 255 + + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + if (k < comp) { + float z = data[i * comp + k] * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1872,750 +2090,791 @@ static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache - -typedef struct -{ - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct { + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct -{ - stbi__context *s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; - -// sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; - -// definition of jpeg image component - struct - { - int id; - int h,v; - int tq; - int hd,ha; - int dc_pred; - - int x,y,w2,h2; - stbi_uc *data; - void *raw_data, *raw_coeff; - stbi_uc *linebuf; - short *coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; - - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop - - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; - - int scan_n, order[4]; - int restart_interval, todo; - -// kernels - void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); - stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +typedef struct { + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, int count, + int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman *h, int *count) -{ - int i,j,k=0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i=0; i < 16; ++i) - for (j=0; j < count[i]; ++j) - h->size[k++] = (stbi_uc) (i+1); - h->size[k] = 0; - - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for(j=1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16) (code++); - if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16-j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; - - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i=0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS-s); - int m = 1 << (FAST_BITS-s); - for (j=0; j < m; ++j) { - h->fast[c+j] = (stbi_uc) i; - } +static int stbi__build_huffman(stbi__huffman *h, int *count) { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) + for (j = 0; j < count[i]; ++j) + h->size[k++] = (stbi_uc)(i + 1); + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16)(code++); + if (code - 1 >= (1u << j)) + return stbi__err("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (stbi_uc)i; } - } - return 1; + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) -{ - int i; - for (i=0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; - - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); - } +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) + k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); } - } -} - -static void stbi__grow_buffer_unsafe(stbi__jpeg *j) -{ - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char) c; - j->nomore = 1; - return; - } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) { + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) + c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; +static const stbi__uint32 stbi__bmask[17] = { + 0, 1, 3, 7, 15, 31, 63, 127, 255, + 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) -{ - unsigned int temp; - int c,k; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } - - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k=FAST_BITS+1 ; ; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } - - if (k > j->code_bits) +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) return -1; - - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1;; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & + stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) stbi__grow_buffer_unsafe(j); - - sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB - k = stbi_lrot(j->code_buffer, n); - if (n < 0 || n >= (int) (sizeof(stbi__bmask)/sizeof(*stbi__bmask))) return 0; - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & ~sgn); +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) { + unsigned int k; + int sgn; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + + sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB + k = stbi_lrot(j->code_buffer, n); + if (n < 0 || n >= (int)(sizeof(stbi__bmask) / sizeof(*stbi__bmask))) + return 0; + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & ~sgn); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) -{ - unsigned int k; - if (j->code_bits < n) stbi__grow_buffer_unsafe(j); - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; -} - -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) -{ - unsigned int k; - if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) { + unsigned int k; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) { + unsigned int k; + if (j->code_bits < 1) + stbi__grow_buffer_unsafe(j); + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64+15] = -{ - 0, 1, 8, 16, 9, 2, 3, 10, - 17, 24, 32, 25, 18, 11, 4, 5, - 12, 19, 26, 33, 40, 48, 41, 34, - 27, 20, 13, 6, 7, 14, 21, 28, - 35, 42, 49, 56, 57, 50, 43, 36, - 29, 22, 15, 23, 30, 37, 44, 51, - 58, 59, 52, 45, 38, 31, 39, 46, - 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, - 63, 63, 63, 63, 63, 63, 63 -}; +static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { + 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, + 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, + 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) -{ - int diff,dc,k; - int t; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - - // 0 all the ac values now so we can do it 32-bits at a time - memset(data,0,64*sizeof(data[0])); - - diff = t ? stbi__extend_receive(j, t) : 0; - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc * dequant[0]); - - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) * dequant[zig]); +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, stbi__huffman *hac, + stbi__int16 *fac, int b, + stbi__uint16 *dequant) { + int diff, dc, k; + int t; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data, 0, 64 * sizeof(data[0])); + + diff = t ? stbi__extend_receive(j, t) : 0; + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) + break; // end block + k += 16; } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (rs != 0xf0) break; // end block - k += 16; - } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); - } + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); } - } while (k < 64); - return 1; -} - -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) -{ - int diff,dc; - int t; - if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data,0,64*sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t == -1) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; - - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc << j->succ_low); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short) (1 << j->succ_low); - } - return 1; + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, int b) { + int diff, dc; + int t; + if (j->spec_end != 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t == -1) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; + + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc << j->succ_low); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; } // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) -{ - int k; - if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->succ_high == 0) { - int shift = j->succ_low; +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], + stbi__huffman *hac, + stbi__int16 *fac) { + int k; + if (j->spec_start == 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } - if (j->eob_run) { - --j->eob_run; - return 1; + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) << shift); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) << shift); + } } - + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { k = j->spec_start; do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) << shift); - } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) << shift); - } - } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients - - short bit = (short) (1 << j->succ_low); - - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short *p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } - } else { - k = j->spec_start; - do { - int r,s; - int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } + int r, s; + int rs = stbi__jpeg_huff_decode( + j, hac); // @OPTIMIZE see if we can use the fast path here, + // advance-by-r is so slow, eh + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) + return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } - // advance by r - while (k <= j->spec_end) { - short *p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short) s; - break; - } - --r; - } + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short)s; + break; } - } while (k <= j->spec_end); - } - } - return 1; + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) -{ - // trick to use a single test to catch both cases - if ((unsigned int) x > 255) { - if (x < 0) return 0; - if (x > 255) return 255; - } - return (stbi_uc) x; +stbi_inline static stbi_uc stbi__clamp(int x) { + // trick to use a single test to catch both cases + if ((unsigned int)x > 255) { + if (x < 0) + return 0; + if (x > 255) + return 255; + } + return (stbi_uc)x; } -#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) -#define stbi__fsh(x) ((x) * 4096) +#define stbi__f2f(x) ((int)(((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ - int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2+p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3*stbi__f2f(-1.847759065f); \ - t3 = p1 + p2*stbi__f2f( 0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2+p3); \ - t1 = stbi__fsh(p2-p3); \ - x0 = t0+t3; \ - x3 = t0-t3; \ - x1 = t1+t2; \ - x2 = t1-t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0+t2; \ - p4 = t1+t3; \ - p1 = t0+t3; \ - p2 = t1+t2; \ - p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ - t0 = t0*stbi__f2f( 0.298631336f); \ - t1 = t1*stbi__f2f( 2.053119869f); \ - t2 = t2*stbi__f2f( 3.072711026f); \ - t3 = t3*stbi__f2f( 1.501321110f); \ - p1 = p5 + p1*stbi__f2f(-0.899976223f); \ - p2 = p5 + p2*stbi__f2f(-2.562915447f); \ - p3 = p3*stbi__f2f(-1.961570560f); \ - p4 = p4*stbi__f2f(-0.390180644f); \ - t3 += p1+p4; \ - t2 += p2+p3; \ - t1 += p2+p4; \ - t0 += p1+p3; - -static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) -{ - int i,val[64],*v=val; - stbi_uc *o; - short *d = data; - - // columns - for (i=0; i < 8; ++i,++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 - && d[40]==0 && d[48]==0 && d[56]==0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0]*4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; x1 += 512; x2 += 512; x3 += 512; - v[ 0] = (x0+t3) >> 10; - v[56] = (x0-t3) >> 10; - v[ 8] = (x1+t2) >> 10; - v[48] = (x1-t2) >> 10; - v[16] = (x2+t1) >> 10; - v[40] = (x2-t1) >> 10; - v[24] = (x3+t0) >> 10; - v[32] = (x3-t0) >> 10; - } - } - - for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128<<17); - x1 += 65536 + (128<<17); - x2 += 65536 + (128<<17); - x3 += 65536 + (128<<17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0+t3) >> 17); - o[7] = stbi__clamp((x0-t3) >> 17); - o[1] = stbi__clamp((x1+t2) >> 17); - o[6] = stbi__clamp((x1-t2) >> 17); - o[2] = stbi__clamp((x2+t1) >> 17); - o[5] = stbi__clamp((x2-t1) >> 17); - o[3] = stbi__clamp((x3+t0) >> 17); - o[4] = stbi__clamp((x3-t0) >> 17); - } +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ + t0 += p1 + p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) { + int i, val[64], *v = val; + stbi_uc *o; + short *d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && + d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; + x1 += 512; + x2 += 512; + x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0 + t3) >> 17); + o[7] = stbi__clamp((x0 - t3) >> 17); + o[1] = stbi__clamp((x1 + t2) >> 17); + o[6] = stbi__clamp((x1 - t2) >> 17); + o[2] = stbi__clamp((x2 + t1) >> 17); + o[5] = stbi__clamp((x2 - t1) >> 17); + o[3] = stbi__clamp((x3 + t0) >> 17); + o[4] = stbi__clamp((x3 - t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; - - // dot product constant: even elems=x, odd elems=y - #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) - - // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) - // out(1) = c1[even]*x + c1[odd]*y - #define dct_rot(out0,out1, x,y,c0,c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) - - // out = in << 12 (in 16-bit, out 32-bit) - #define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - - // wide add - #define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - - // wide sub - #define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) - - // butterfly a/b, add bias, then shift by "s" and pack - #define dct_bfly32o(out0, out1, a,b,bias,s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } - - // 8-bit interleave step (for transposes) - #define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) - - // 16-bit interleave step (for transposes) - #define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) - - #define dct_pass(bias,shift) \ - { \ - /* even part */ \ - dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ - dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0,row7, x0,x7,bias,shift); \ - dct_bfly32o(row1,row6, x1,x6,bias,shift); \ - dct_bfly32o(row2,row5, x2,x5,bias,shift); \ - dct_bfly32o(row3,row4, x3,x4,bias,shift); \ - } +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + +// dot product constant: even elems=x, odd elems=y +#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + +// out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = \ + _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = \ + _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); - - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); - - // load - row0 = _mm_load_si128((const __m128i *) (data + 0*8)); - row1 = _mm_load_si128((const __m128i *) (data + 1*8)); - row2 = _mm_load_si128((const __m128i *) (data + 2*8)); - row3 = _mm_load_si128((const __m128i *) (data + 3*8)); - row4 = _mm_load_si128((const __m128i *) (data + 4*8)); - row5 = _mm_load_si128((const __m128i *) (data + 5*8)); - row6 = _mm_load_si128((const __m128i *) (data + 6*8)); - row7 = _mm_load_si128((const __m128i *) (data + 7*8)); - - // column pass - dct_pass(bias_0, 10); - - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); - - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); - - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } - - // row pass - dct_pass(bias_1, 17); - - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); - - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... - - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... - - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... +// wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - // store - _mm_storel_epi64((__m128i *) out, p0); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p2); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p1); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p3); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); - } +// wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + +// butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = \ + _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = \ + _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + +// 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + +// 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), + stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), + stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), + stbi__f2f(1.175875602f)); + __m128i rot1_1 = + dct_const(stbi__f2f(1.175875602f), + stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), + stbi__f2f(-1.961570560f)); + __m128i rot2_1 = + dct_const(stbi__f2f(-1.961570560f), + stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), + stbi__f2f(-0.390180644f)); + __m128i rot3_1 = + dct_const(stbi__f2f(-0.390180644f), + stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); + row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); + row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); + row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); + row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); + row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); + row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); + row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *)out, p0); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p2); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p1); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p3); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2634,198 +2893,236 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); - -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) - -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) - -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ - dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ - dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ - dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ - } - - // load - row0 = vld1q_s16(data + 0*8); - row1 = vld1q_s16(data + 1*8); - row2 = vld1q_s16(data + 2*8); - row3 = vld1q_s16(data + 3*8); - row4 = vld1q_s16(data + 4*8); - row5 = vld1q_s16(data + 5*8); - row6 = vld1q_s16(data + 6*8); - row7 = vld1q_s16(data + 7*8); - - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - - // column pass - dct_pass(vrshrn_n_s32, 10); - - // 16bit 8x8 transpose - { +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ + } + + // load + row0 = vld1q_s16(data + 0 * 8); + row1 = vld1q_s16(data + 1 * 8); + row2 = vld1q_s16(data + 2 * 8); + row3 = vld1q_s16(data + 3 * 8); + row4 = vld1q_s16(data + 4 * 8); + row5 = vld1q_s16(data + 5 * 8); + row6 = vld1q_s16(data + 6 * 8); + row7 = vld1q_s16(data + 7 * 8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } -#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } - - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); - - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); - - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = \ + vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ + } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } - - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); - - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } -#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } - - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! - - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); - - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); - - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); - - // store - vst1_u8(out, p0); out += out_stride; - vst1_u8(out, p1); out += out_stride; - vst1_u8(out, p2); out += out_stride; - vst1_u8(out, p3); out += out_stride; - vst1_u8(out, p4); out += out_stride; - vst1_u8(out, p5); out += out_stride; - vst1_u8(out, p6); out += out_stride; - vst1_u8(out, p7); + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ + } +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ + } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); + out += out_stride; + vst1_u8(out, p1); + out += out_stride; + vst1_u8(out, p2); + out += out_stride; + vst1_u8(out, p3); + out += out_stride; + vst1_u8(out, p4); + out += out_stride; + vst1_u8(out, p5); + out += out_stride; + vst1_u8(out, p6); + out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -2838,1132 +3135,1274 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg *j) -{ - stbi_uc x; - if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } - x = stbi__get8(j->s); - if (x != 0xff) return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) { + stbi_uc x; + if (j->marker != STBI__MARKER_none) { + x = j->marker; + j->marker = STBI__MARKER_none; + return x; + } + x = stbi__get8(j->s); + if (x != 0xff) + return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg *j) -{ - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels -} - -static int stbi__parse_entropy_coded_data(stbi__jpeg *z) -{ - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i,j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - STBI_SIMD_ALIGN(short, data[64]); - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x)*8; - int y2 = (j*z->img_comp[n].v + y)*8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; +static void stbi__jpeg_reset(stbi__jpeg *j) { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = + j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) { + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } else { - if (z->scan_n == 1) { - int i,j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x); - int y2 = (j*z->img_comp[n].v + y); - short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } + return 1; + } else { // interleaved + int i, j, k, x, y; + STBI_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, + z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + + z->img_comp[n].w2 * y2 + x2, + z->img_comp[n].w2, data); + } } - } - return 1; + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) -{ - int i; - for (i=0; i < 64; ++i) - data[i] *= dequant[i]; -} - -static void stbi__jpeg_finish(stbi__jpeg *z) -{ - if (z->progressive) { - // dequantize and idct the data - int i,j,n; - for (n=0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - } - } + return 1; + } + } else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], + z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static int stbi__process_marker(stbi__jpeg *z, int m) -{ - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker","Corrupt JPEG"); - - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; - - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s)-2; - while (L > 0) { - int q = stbi__get8(z->s); - int p = q >> 4, sixteen = (p != 0); - int t = q & 15,i; - if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); - if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - - for (i=0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); - L -= (sixteen ? 129 : 65); - } - return L==0; - - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s)-2; - while (L > 0) { - stbi_uc *v; - int sizes[16],i,n=0; - int q = stbi__get8(z->s); - int tc = q >> 4; - int th = q & 15; - if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); - for (i=0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; - } - L -= 17; - if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; - v = z->huff_dc[th].values; - } else { - if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; - v = z->huff_ac[th].values; + return 1; + } else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } } - for (i=0; i < n; ++i) - v[i] = stbi__get8(z->s); - if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); - L -= n; - } - return L==0; - } - - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len","Corrupt JPEG"); - else - return stbi__err("bad APP len","Corrupt JPEG"); + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - L -= 2; - - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J','F','I','F','\0'}; - int ok = 1; - int i; - for (i=0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; - int ok = 1; - int i; - for (i=0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + } } + } + } +} - stbi__skip(z->s, L); - return 1; - } +static int stbi__process_marker(stbi__jpeg *z, int m) { + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker", "Corrupt JPEG"); - return stbi__err("unknown marker","Corrupt JPEG"); -} + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) + return stbi__err("bad DRI len", "Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; -// after we see SOS -static int stbi__process_scan_header(stbi__jpeg *z) -{ - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); - if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); - for (i=0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s) - 2; + while (L > 0) { int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) return 0; // no match - z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); - z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); - z->order[i] = which; - } - - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) + return stbi__err("bad DQT type", "Corrupt JPEG"); + if (t > 3) + return stbi__err("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = + (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + stbi_uc *v; + int sizes[16], i, n = 0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) + return stbi__err("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc + th, sizes)) + return 0; + v = z->huff_dc[th].values; } else { - if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); - z->spec_end = 63; + if (!stbi__build_huffman(z->huff_ac + th, sizes)) + return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len", "Corrupt JPEG"); + else + return stbi__err("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; } - } + } + + stbi__skip(z->s, L); + return 1; + } - return 1; + return stbi__err("unknown marker", "Corrupt JPEG"); } -static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) -{ - int i; - for (i=0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) { + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) + return stbi__err("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) + return stbi__err("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) + return 0; // no match + z->img_comp[which].hd = q >> 4; + if (z->img_comp[which].hd > 3) + return stbi__err("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; + if (z->img_comp[which].ha > 3) + return stbi__err("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || + z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; } -static int stbi__process_frame_header(stbi__jpeg *z, int scan) -{ - stbi__context *s = z->s; - int Lf,p,i,q, h_max=1,v_max=1,c; - Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG - p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); - s->img_n = c; - for (i=0; i < c; ++i) { +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } - - if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); - - z->rgb = 0; - for (i=0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = { 'R', 'G', 'B' }; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); - z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); - } - - if (scan != STBI__SCAN_load) return 1; - - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); - - for (i=0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; - } - - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; - - for (i=0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); - } - } + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) { + stbi__context *s = z->s; + int Lf, p, i, q, h_max = 1, v_max = 1, c; + Lf = stbi__get16be(s); + if (Lf < 11) + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + p = stbi__get8(s); + if (p != 8) + return stbi__err("only 8-bit", + "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); + if (s->img_y == 0) + return stbi__err( + "no header height", + "JPEG format not supported: delayed height"); // Legal, but we don't + // handle it--but neither + // does IJG + s->img_x = stbi__get16be(s); + if (s->img_x == 0) + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) + return stbi__err("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) + return stbi__err("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = {'R', 'G', 'B'}; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); + if (!z->img_comp[i].h || z->img_comp[i].h > 4) + return stbi__err("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; + if (!z->img_comp[i].v || z->img_comp[i].v > 4) + return stbi__err("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); + if (z->img_comp[i].tq > 3) + return stbi__err("bad TQ", "Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) + return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) + return stbi__err("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) + h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) + v_max = z->img_comp[i].v; + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked + // earlier) so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = + stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i + 1, + stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = + (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3( + z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components( + z, i + 1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = + (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } - return 1; + return 1; } // use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) - -#define stbi__SOF_progressive(x) ((x) == 0xc2) - -static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) -{ - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); - if (scan == STBI__SCAN_type) return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z,m)) return 0; +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) + return stbi__err("no SOI", "Corrupt JPEG"); + if (scan == STBI__SCAN_type) + return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z, m)) + return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) + return stbi__err("no SOF", "Corrupt JPEG"); m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) return 0; - return 1; + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) + return 0; + return 1; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg *j) -{ - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) return 0; - if (!stbi__parse_entropy_coded_data(j)) return 0; - if (j->marker == STBI__MARKER_none ) { - // handle 0s at the end of image data from IP Kamera 9060 - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - if (x == 255) { - j->marker = stbi__get8(j->s); - break; - } - } - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); - } else { - if (!stbi__process_marker(j, m)) return 0; +static int stbi__decode_jpeg_image(stbi__jpeg *j) { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) + return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) + return 0; + if (!stbi__parse_entropy_coded_data(j)) + return 0; + if (j->marker == STBI__MARKER_none) { + // handle 0s at the end of image data from IP Kamera 9060 + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + if (x == 255) { + j->marker = stbi__get8(j->s); + break; + } + } + // if we reach eof without hitting a marker, stbi__get_marker() below + // will fail and we'll eventually return 0 } - m = stbi__get_marker(j); - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) + return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) + return stbi__err("bad DNL height", "Corrupt JPEG"); + } else { + if (!stbi__process_marker(j, m)) + return 0; + } + m = stbi__get_marker(j); + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, - int w, int hs); - -#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) - -static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; -} - -static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i=0; i < w; ++i) - out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); - return out; + int w, int hs); + +#define stbi__div4(x) ((stbi_uc)((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, + int w, int hs) { + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc *stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc *stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = stbi__div4(n + input[i - 1]); + out[i * 2 + 1] = stbi__div4(n + input[i + 1]); + } + out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc)((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = stbi__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; } -static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples horizontally for every one in input - int i; - stbi_uc *input = in_near; - - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } - - out[0] = input[0]; - out[1] = stbi__div4(input[0]*3 + input[1] + 2); - for (i=1; i < w-1; ++i) { - int n = 3*input[i]+2; - out[i*2+0] = stbi__div4(n+input[i-1]); - out[i*2+1] = stbi__div4(n+input[i+1]); - } - out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); - out[i*2+1] = input[w-1]; - - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); - - return out; -} +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = + _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *)(out + i * 2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = + vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif -#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } -static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i,t0,t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - t1 = 3*in_near[0] + in_far[0]; - out[0] = stbi__div4(t1+2); - for (i=1; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } +#endif -#if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i=0,t0,t1; - - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } - - t1 = 3*in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w-1) & ~7); i += 8) { -#if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); - - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); - - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *) (out + i*2), outv); -#elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); - - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i*2, o); -#endif - - // "previous" value for next iter - t1 = 3*in_near[i+7] + in_far[i+7]; - } - - t0 = t1; - t1 = 3*in_near[i] + in_far[i]; - out[i*2] = stbi__div16(3*t1 + t0 + 8); - - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); - - STBI_NOTUSED(hs); - - return out; -} -#endif - -static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // resample with nearest-neighbor - int i,j; - STBI_NOTUSED(in_far); - for (i=0; i < w; ++i) - for (j=0; j < hs; ++j) - out[i*hs+j] = in_near[i]; - return out; +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // resample with nearest-neighbor + int i, j; + STBI_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) -{ - int i; - for (i=0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int)((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, + int count, int step) { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) -{ - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, + stbi_uc const *pcb, stbi_uc const *pcr, + int count, int step) { + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); - __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); - __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); - __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); - __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); - __m128i xw = _mm_set1_epi16(255); // alpha channel - - for (; i+7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); - - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); - - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); - - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); - - // store - _mm_storeu_si128((__m128i *) (out + 0), o0); - _mm_storeu_si128((__m128i *) (out + 16), o1); - out += 32; - } - } + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *)(out + 0), o0); + _mm_storeu_si128((__m128i *)(out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); - int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); - int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); - int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); - - for (; i+7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); - - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); - - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); - - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8*4; - } - } -#endif - - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + cr * -stbi__float2fixed(0.71414f) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg *j) -{ - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) { + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg *j) -{ - stbi__free_jpeg_components(j, j->s->img_n, 0); +static void stbi__cleanup_jpeg(stbi__jpeg *j) { + stbi__free_jpeg_components(j, j->s->img_n, 0); } -typedef struct -{ - resample_row_func resample; - stbi_uc *line0,*line1; - int hs,vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +typedef struct { + resample_row_func resample; + stbi_uc *line0, *line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) -{ - unsigned int t = x*y + 128; - return (stbi_uc) ((t + (t >>8)) >> 8); -} - -static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) -{ - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe - - // validate req_comp - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } - - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; - - // resample and color-convert - { - int k; - unsigned int i,j; - stbi_uc *output; - stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; - - stbi__resample res_comp[4]; - - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs-1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; - else r->resample = stbi__resample_row_generic; +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { + unsigned int t = x * y + 128; + return (stbi_uc)((t + (t >> 8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, + int *comp, int req_comp) { + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && + (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // resample and color-convert + { + int k; + unsigned int i, j; + stbi_uc *output; + stbi_uc *coutput[4] = {NULL, NULL, NULL, NULL}; + + stbi__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); } - // can't error after this so, this is safe - output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - // now go ahead and resample - for (j=0; j < z->s->img_y; ++j) { - stbi_uc *out = output + n * z->s->img_x * j; - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, - y_bot ? r->line1 : r->line0, - y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) + r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) + r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) + r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) + r->resample = z->resample_row_hv_2_kernel; + else + r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = + r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; } - } - if (n >= 3) { - stbi_uc *y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i=0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } - } else { - if (is_rgb) { - if (n == 1) - for (i=0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i=0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc *y = coutput[0]; - if (n == 1) - for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; - else - for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; } - } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { + *out++ = y[i]; + *out++ = 255; + } + } } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } -} - -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - unsigned char* result; - stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x,y,comp,req_comp); - STBI_FREE(j); - return result; -} - -static int stbi__jpeg_test(stbi__context *s) -{ - int r; - stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; -} - -static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) -{ - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind( j->s ); - return 0; - } - if (x) *x = j->s->img_x; - if (y) *y = j->s->img_y; - if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; -} - -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) -{ - int result; - stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) + *comp = + z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + unsigned char *result; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) { + int r; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) { + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind(j->s); + return 0; + } + if (x) + *x = j->s->img_x; + if (y) + *y = j->s->img_y; + if (comp) + *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) { + int result; + stbi__jpeg *j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -3977,83 +4416,81 @@ static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct -{ - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[288]; - stbi__uint16 value[288]; +typedef struct { + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[288]; + stbi__uint16 value[288]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) -{ - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); +stbi_inline static int stbi__bitreverse16(int n) { + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); return n; } -stbi_inline static int stbi__bit_reverse(int v, int bits) -{ - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16-bits); -} - -static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) -{ - int i,k=0; - int code, next_code[16], sizes[17]; - - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i=0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i=1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i=1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16) code; - z->firstsymbol[i] = (stbi__uint16) k; - code = (code + sizes[i]); - if (sizes[i]) - if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); - z->maxcode[i] = code << (16-i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i=0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); - z->size [c] = (stbi_uc ) s; - z->value[c] = (stbi__uint16) i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s],s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } - } - ++next_code[s]; +stbi_inline static int stbi__bit_reverse(int v, int bits) { + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16 - bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, + int num) { + int i, k = 0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i = 0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i = 1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i = 1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16)code; + z->firstsymbol[i] = (stbi__uint16)k; + code = (code + sizes[i]); + if (sizes[i]) + if (code - 1 >= (1 << i)) + return stbi__err("bad codelengths", "Corrupt PNG"); + z->maxcode[i] = code << (16 - i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i = 0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); + z->size[c] = (stbi_uc)s; + z->value[c] = (stbi__uint16)i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s], s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } } - } - return 1; + ++next_code[s]; + } + } + return 1; } // zlib-from-memory implementation for PNG reading @@ -4062,277 +4499,313 @@ static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int // we require PNG read all the IDATs and combine them into a single // memory buffer -typedef struct -{ - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; +typedef struct { + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; - char *zout; - char *zout_start; - char *zout_end; - int z_expandable; + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; - stbi__zhuffman z_length, z_distance; + stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf *z) -{ - return (z->zbuffer >= z->zbuffer_end); -} - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) -{ - return stbi__zeof(z) ? 0 : *z->zbuffer++; -} - -static void stbi__fill_bits(stbi__zbuf *z) -{ - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); +stbi_inline static int stbi__zeof(stbi__zbuf *z) { + return (z->zbuffer >= z->zbuffer_end); } -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) -{ - unsigned int k; - if (z->num_bits < n) stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) { + return stbi__zeof(z) ? 0 : *z->zbuffer++; } -static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s,k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s=STBI__ZFAST_BITS+1; ; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) return -1; // invalid code! - // code size is s, so: - b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= sizeof (z->size)) return -1; // some data was corrupt somewhere! - if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes -{ - char *q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); - cur = (unsigned int) (z->zout - z->zout_start); - limit = old_limit = (unsigned) (z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; +static void stbi__fill_bits(stbi__zbuf *z) { + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) { + unsigned int k; + if (z->num_bits < n) + stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s, k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s = STBI__ZFAST_BITS + 1;; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) + return -1; // invalid code! + // code size is s, so: + b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= sizeof(z->size)) + return -1; // some data was corrupt somewhere! + if (z->size[b] != s) + return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. */ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, + int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) + return stbi__err("output buffer limit", "Corrupt PNG"); + cur = (unsigned int)(z->zout - z->zout_start); + limit = old_limit = (unsigned)(z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned)n) + return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if (limit > UINT_MAX / 2) + return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) + return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; } static const int stbi__zlength_base[31] = { - 3,4,5,6,7,8,9,10,11,13, - 15,17,19,23,27,31,35,43,51,59, - 67,83,99,115,131,163,195,227,258,0,0 }; - -static const int stbi__zlength_extra[31]= -{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; - -static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, -257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; - -static const int stbi__zdist_extra[32] = -{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; - -static int stbi__parse_huffman_block(stbi__zbuf *a) -{ - char *zout = a->zout; - for(;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) return 0; - zout = a->zout; - } - *zout++ = (char) z; + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + +static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, + 4, 4, 5, 5, 5, 5, 0, 0, 0}; + +static const int stbi__zdist_base[32] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; + +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) { + char *zout = a->zout; + for (;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) + return stbi__err("bad huffman code", + "Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) + return 0; + zout = a->zout; + } + *zout++ = (char)z; + } else { + stbi_uc *p; + int len, dist; + if (z == 256) { + a->zout = zout; + return 1; + } + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) + len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0) + return stbi__err("bad huffman code", "Corrupt PNG"); + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) + dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) + return stbi__err("bad dist", "Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) + return 0; + zout = a->zout; + } + p = (stbi_uc *)(zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { + do + *zout++ = v; + while (--len); + } } else { - stbi_uc *p; - int len,dist; - if (z == 256) { - a->zout = zout; - return 1; - } - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) return 0; - zout = a->zout; - } - p = (stbi_uc *) (zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { do *zout++ = v; while (--len); } - } else { - if (len) { do *zout++ = *p++; while (--len); } - } + if (len) { + do + *zout++ = *p++; + while (--len); + } } - } -} - -static int stbi__compute_huffman_codes(stbi__zbuf *a) -{ - static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286+32+137];//padding for maximum single op - stbi_uc codelength_sizes[19]; - int i,n; - - int hlit = stbi__zreceive(a,5) + 257; - int hdist = stbi__zreceive(a,5) + 1; - int hclen = stbi__zreceive(a,4) + 4; - int ntot = hlit + hdist; - - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i=0; i < hclen; ++i) { - int s = stbi__zreceive(a,3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; - - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc) c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a,2)+3; - if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n-1]; - } else if (c == 17) { - c = stbi__zreceive(a,3)+3; - } else if (c == 18) { - c = stbi__zreceive(a,7)+11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); - } - if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes+n, fill, c); - n += c; + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) { + static const stbi_uc length_dezigzag[19] = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc codelength_sizes[19]; + int i, n; + + int hlit = stbi__zreceive(a, 5) + 257; + int hdist = stbi__zreceive(a, 5) + 1; + int hclen = stbi__zreceive(a, 4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i = 0; i < hclen; ++i) { + int s = stbi__zreceive(a, 3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) + return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc)c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a, 2) + 3; + if (n == 0) + return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n - 1]; + } else if (c == 17) { + c = stbi__zreceive(a, 3) + 3; + } else if (c == 18) { + c = stbi__zreceive(a, 7) + 11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); } - } - if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; - return 1; -} - -static int stbi__parse_uncompressed_block(stbi__zbuf *a) -{ - stbi_uc header[4]; - int len,nlen,k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; -} - -static int stbi__parse_zlib_header(stbi__zbuf *a) -{ - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; -} - -static const stbi_uc stbi__zdefault_length[288] = -{ - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 -}; -static const stbi_uc stbi__zdefault_distance[32] = -{ - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 -}; + if (ntot - n < c) + return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes + n, fill, c); + n += c; + } + } + if (n != ntot) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) + return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) { + stbi_uc header[4]; + int len, nlen, k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = + (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) + return stbi__err("zlib corrupt", "Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) + return stbi__err("zlib corrupt", "Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) + return stbi__err("read past buffer", "Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) + return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) { + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if ((cmf * 256 + flg) % 31 != 0) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if (flg & 32) + return stbi__err("no preset dict", + "Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) + return stbi__err("bad compression", + "Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[288] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; +static const stbi_uc stbi__zdefault_distance[32] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; /* Init algorithm: { @@ -4346,117 +4819,131 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) -{ - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a,1); - type = stbi__zreceive(a,2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) return 0; - } else if (type == 3) { - return 0; +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) { + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) + return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a, 1); + type = stbi__zreceive(a, 2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) + return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, 288)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) + return 0; } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , 288)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; - } else { - if (!stbi__compute_huffman_codes(a)) return 0; - } - if (!stbi__parse_huffman_block(a)) return 0; + if (!stbi__compute_huffman_codes(a)) + return 0; } - } while (!final); - return 1; -} - -static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) -{ - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; - - return stbi__parse_zlib(a, parse_header); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) -{ - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int) (a.zout - a.zout_start); - else - return -1; -} - -STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(16384); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer+len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int) (a.zout - a.zout_start); - else - return -1; + if (!stbi__parse_huffman_block(a)) + return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, + int parse_header) { + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, + int *outlen) { + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + char const *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int)(a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, + int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(16384); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int)(a.zout - a.zout_start); + else + return -1; } #endif @@ -4471,1083 +4958,1312 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct -{ - stbi__uint32 length; - stbi__uint32 type; +typedef struct { + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) -{ - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) { + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context *s) -{ - static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; - int i; - for (i=0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) { + static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + int i; + for (i = 0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) + return stbi__err("bad png sig", "Not a PNG"); + return 1; } -typedef struct -{ - stbi__context *s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct { + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; - enum { - STBI__F_none=0, - STBI__F_sub=1, - STBI__F_up=2, - STBI__F_avg=3, - STBI__F_paeth=4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none = 0, + STBI__F_sub = 1, + STBI__F_up = 2, + STBI__F_avg = 3, + STBI__F_paeth = 4, + // synthetic filters used for first scanline to avoid needing a dummy row of + // 0s + STBI__F_avg_first, + STBI__F_paeth_first }; -static stbi_uc first_row_filter[5] = -{ - STBI__F_none, - STBI__F_sub, - STBI__F_none, - STBI__F_avg_first, - STBI__F_paeth_first -}; +static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, + STBI__F_avg_first, STBI__F_paeth_first}; -static int stbi__paeth(int a, int b, int c) -{ - int p = a + b - c; - int pa = abs(p-a); - int pb = abs(p-b); - int pc = abs(p-c); - if (pa <= pb && pa <= pc) return a; - if (pb <= pc) return b; - return c; +static int stbi__paeth(int a, int b, int c) { + int p = a + b - c; + int pa = abs(p - a); + int pb = abs(p - b); + int pc = abs(p - c); + if (pa <= pb && pa <= pc) + return a; + if (pb <= pc) + return b; + return c; } -static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; +static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, + 0, 0, 0, 0x01}; // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) -{ - int bytes = (depth == 16? 2 : 1); - stbi__context *s = a->s; - stbi__uint32 i,j,stride = x*out_n*bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later - - int output_bytes = out_n*bytes; - int filter_bytes = img_n*bytes; - int width = x; - - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); - a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) return stbi__err("outofmem", "Out of memory"); - - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; - - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); - - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *prior; - int filter = *raw++; - - if (filter > 4) - return stbi__err("invalid filter","Corrupt PNG"); - - if (depth < 8) { - if (img_width_bytes > x) return stbi__err("invalid width","Corrupt PNG"); - cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - - // if first row, use special filter that doesn't sample previous row - if (j == 0) filter = first_row_filter[filter]; - - // handle first byte explicitly - for (k=0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none : cur[k] = raw[k]; break; - case STBI__F_sub : cur[k] = raw[k]; break; - case STBI__F_up : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break; - case STBI__F_avg : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break; - case STBI__F_paeth : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break; - case STBI__F_avg_first : cur[k] = raw[k]; break; - case STBI__F_paeth_first: cur[k] = raw[k]; break; - } +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, + stbi__uint32 raw_len, int out_n, + stbi__uint32 x, stbi__uint32 y, int depth, + int color) { + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i, j, stride = x * out_n * bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n * bytes; + int filter_bytes = img_n * bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); + a->out = (stbi_uc *)stbi__malloc_mad3( + x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) + return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) + return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on + // non-interlaced PNGs, but issue #276 reported a PNG in the wild that had + // extra data at the end (all zeros), so just check for raw_len < img_len + // always. + if (raw_len < img_len) + return stbi__err("not enough pixels", "Corrupt PNG"); + + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter", "Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) + return stbi__err("invalid width", "Corrupt PNG"); + cur += + x * out_n - img_width_bytes; // store output to the rightmost img_len + // bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = + cur - + stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) + filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k = 0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none: + cur[k] = raw[k]; + break; + case STBI__F_sub: + cur[k] = raw[k]; + break; + case STBI__F_up: + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); + break; + case STBI__F_paeth: + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); + break; + case STBI__F_avg_first: + cur[k] = raw[k]; + break; + case STBI__F_paeth_first: + cur[k] = raw[k]; + break; } + } - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes+1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1)*filter_bytes; - #define STBI__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: memcpy(cur, raw, nk); break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break; - } - #undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n+1 == out_n); - #define STBI__CASE(f) \ - case f: \ - for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break; - } - #undef STBI__CASE - - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride*j; // start at the beginning of the row again - for (i=0; i < x; ++i,cur+=output_bytes) { - cur[filter_bytes+1] = 255; - } - } + // this is a little gross, so that we don't switch per-pixel or + // per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1) * filter_bytes; +#define STBI__CASE(f) \ + case f: \ + for (k = 0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. + case STBI__F_none: + memcpy(cur, raw, nk); + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - filter_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - filter_bytes], prior[k], + prior[k - filter_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); + } + break; } - } - - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *in = a->out + stride*j + x*out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit - // png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k=x*img_n; k >= 2; k-=2, ++in) { - *cur++ = scale * ((*in >> 4) ); - *cur++ = scale * ((*in ) & 0x0f); - } - if (k > 0) *cur++ = scale * ((*in >> 4) ); - } else if (depth == 2) { - for (k=x*img_n; k >= 4; k-=4, ++in) { - *cur++ = scale * ((*in >> 6) ); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in ) & 0x03); - } - if (k > 0) *cur++ = scale * ((*in >> 6) ); - if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k=x*img_n; k >= 8; k-=8, ++in) { - *cur++ = scale * ((*in >> 7) ); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in ) & 0x01); - } - if (k > 0) *cur++ = scale * ((*in >> 7) ); - if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride*j; - if (img_n == 1) { - for (q=x-1; q >= 0; --q) { - cur[q*2+1] = 255; - cur[q*2+0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q=x-1; q >= 0; --q) { - cur[q*4+3] = 255; - cur[q*4+2] = cur[q*3+2]; - cur[q*4+1] = cur[q*3+1]; - cur[q*4+0] = cur[q*3+0]; - } - } - } +#undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n + 1 == out_n); +#define STBI__CASE(f) \ + case f: \ + for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, \ + cur += output_bytes, prior += output_bytes) \ + for (k = 0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { + cur[k] = raw[k]; + } + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - output_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - output_bytes], prior[k], + prior[k - output_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); + } + break; } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc *cur = a->out; - stbi__uint16 *cur16 = (stbi__uint16*)cur; - - for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) { - *cur16 = (cur[0] << 8) | cur[1]; +#undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride * j; // start at the beginning of the row again + for (i = 0; i < x; ++i, cur += output_bytes) { + cur[filter_bytes + 1] = 255; + } } - } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *in = a->out + stride * j + x * out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common + // 8-bit path optimal at minimal cost for 1/2/4-bit png guarante byte + // alignment, if width is not multiple of 8/4/2 we'll decode dummy + // trailing data that will be skipped in the later loop + stbi_uc scale = (color == 0) + ? stbi__depth_scale_table[depth] + : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than + // desired. we can allocate enough data that this never writes out of + // memory, but it could also overwrite the next scanline. can it overwrite + // non-empty data on the next scanline? yes, consider 1-pixel-wide + // scanlines with 1-bit-per-pixel. so we need to explicitly clamp the + // final ones + + if (depth == 4) { + for (k = x * img_n; k >= 2; k -= 2, ++in) { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } else if (depth == 2) { + for (k = x * img_n; k >= 4; k -= 4, ++in) { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k = x * img_n; k >= 8; k -= 8, ++in) { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) + *cur++ = scale * ((*in >> 7)); + if (k > 1) + *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) + *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) + *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) + *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) + *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) + *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride * j; + if (img_n == 1) { + for (q = x - 1; q >= 0; --q) { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q = x - 1; q >= 0; --q) { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } + } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16 *)cur; + + for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, + stbi__uint32 image_data_len, int out_n, + int depth, int color, int interlaced) { + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, + a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + for (p = 0; p < 7; ++p) { + int xorig[] = {0, 4, 0, 2, 0, 1, 0}; + int yorig[] = {0, 0, 4, 0, 2, 0, 1}; + int xspc[] = {8, 8, 4, 4, 2, 2, 1}; + int yspc[] = {8, 8, 8, 4, 4, 2, 2}; + int i, j, x, y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, + y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j = 0; j < y; ++j) { + for (i = 0; i < x; ++i) { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, + a->out + (j * x + i) * out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - return 1; + return 1; } -static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) -{ - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc *final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - - // de-interlacing - final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - for (p=0; p < 7; ++p) { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; - int i,j,x,y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; - } - for (j=0; j < y; ++j) { - for (i=0; i < x; ++i) { - int out_y = j*yspc[p]+yorig[p]; - int out_x = i*xspc[p]+xorig[p]; - memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, - a->out + (j*x+i)*out_bytes, out_bytes); - } - } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - return 1; -} + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); -static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); - - if (out_n == 2) { - for (i=0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i=0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 *p = (stbi__uint16*) z->out; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], + int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16 *)z->out; - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) -{ - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; - - p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) return stbi__err("outofmem", "Out of memory"); - - // between here and free(out) below, exitting would leak - temp_out = p; - - if (pal_img_n == 3) { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p += 3; - } - } else { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p[3] = palette[n+3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, + int pal_img_n) { + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); - STBI_NOTUSED(len); + // between here and free(out) below, exitting would leak + temp_out = p; - return 1; + if (pal_img_n == 3) { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p += 3; + } + } else { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p[3] = palette[n + 3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; } static int stbi__unpremultiply_on_load = 0; static int stbi__de_iphone_flag = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) -{ - stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) -{ - stbi__de_iphone_flag = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { + stbi__de_iphone_flag = flag_true_if_should_convert; } -static void stbi__de_iphone(stbi__png *z) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - if (s->img_out_n == 3) { // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 3; +static void stbi__de_iphone(stbi__png *z) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i = 0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = (t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i=0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = ( t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } + } else { + // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a, b, c, d) \ + (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + \ + (unsigned)(d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) { + stbi_uc palette[1024], pal_img_n = 0; + stbi_uc has_trans = 0, tc[3] = {0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; + int first = 1, k, interlace = 0, color = 0, is_iphone = 0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) + return 0; + + if (scan == STBI__SCAN_type) + return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { + int comp, filter; + if (!first) + return stbi__err("multiple IHDR", "Corrupt PNG"); + first = 0; + if (c.length != 13) + return stbi__err("bad IHDR len", "Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + z->depth = stbi__get8(s); + if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && + z->depth != 16) + return stbi__err("1/2/4/8/16-bit only", + "PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); + if (color > 6) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3 && z->depth == 16) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3) + pal_img_n = 3; + else if (color & 1) + return stbi__err("bad ctype", "Corrupt PNG"); + comp = stbi__get8(s); + if (comp) + return stbi__err("bad comp method", "Corrupt PNG"); + filter = stbi__get8(s); + if (filter) + return stbi__err("bad filter method", "Corrupt PNG"); + interlace = stbi__get8(s); + if (interlace > 1) + return stbi__err("bad interlace method", "Corrupt PNG"); + if (!s->img_x || !s->img_y) + return stbi__err("0-pixel image", "Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) + return stbi__err("too large", "Image too large to decode"); + if (scan == STBI__SCAN_header) + return 1; } else { - // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) + return stbi__err("too large", "Corrupt PNG"); + // if SCAN_header, have to scan to see if we have a tRNS } - } -} - -#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + break; + } -static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) -{ - stbi_uc palette[1024], pal_img_n=0; - stbi_uc has_trans=0, tc[3]={0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; - int first=1,k,interlace=0, color=0, is_iphone=0; - stbi__context *s = z->s; - - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; - - if (!stbi__check_png_header(s)) return 0; - - if (scan == STBI__SCAN_type) return 1; - - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C','g','B','I'): - is_iphone = 1; - stbi__skip(s, c.length); - break; - case STBI__PNG_TYPE('I','H','D','R'): { - int comp,filter; - if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); - first = 0; - if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); - s->img_x = stbi__get32be(s); - s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); - comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); - filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); - interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); - if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); - if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); - if (scan == STBI__SCAN_header) return 1; - } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); - // if SCAN_header, have to scan to see if we have a tRNS - } - break; - } - - case STBI__PNG_TYPE('P','L','T','E'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); - pal_len = c.length / 3; - if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); - for (i=0; i < pal_len; ++i) { - palette[i*4+0] = stbi__get8(s); - palette[i*4+1] = stbi__get8(s); - palette[i*4+2] = stbi__get8(s); - palette[i*4+3] = 255; - } - break; - } - - case STBI__PNG_TYPE('t','R','N','S'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); - if (pal_img_n) { - if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } - if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); - if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); - pal_img_n = 4; - for (i=0; i < c.length; ++i) - palette[i*4+3] = stbi__get8(s); - } else { - if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); - if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); - has_trans = 1; - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } - } - break; - } - - case STBI__PNG_TYPE('I','D','A','T'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); - if (scan == STBI__SCAN_header) { s->img_n = pal_img_n; return 1; } - if ((int)(ioff + c.length) < (int)ioff) return 0; - if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc *p; - if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); - z->idata = p; - } - if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); - ioff += c.length; - break; - } - - case STBI__PNG_TYPE('I','E','N','D'): { - stbi__uint32 raw_len, bpl; - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) return 1; - if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); - // initial guess for decoded data size to avoid unnecessary reallocs - bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component - raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); - if (z->expanded == NULL) return 0; // zlib should set error - STBI_FREE(z->idata); z->idata = NULL; - if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n+1; - else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; - if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; - } - } - if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); - if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; - } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; - } - STBI_FREE(z->expanded); z->expanded = NULL; - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - return 1; - } - - default: - // if critical, fail - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if ((c.type & (1 << 29)) == 0) { - #ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); - #endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); - } - stbi__skip(s, c.length); - break; + case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256 * 3) + return stbi__err("invalid PLTE", "Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) + return stbi__err("invalid PLTE", "Corrupt PNG"); + for (i = 0; i < pal_len; ++i) { + palette[i * 4 + 0] = stbi__get8(s); + palette[i * 4 + 1] = stbi__get8(s); + palette[i * 4 + 2] = stbi__get8(s); + palette[i * 4 + 3] = 255; } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } -} + break; + } -static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) -{ - void *result=NULL; - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) return result; + case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) + return stbi__err("tRNS after IDAT", "Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { + s->img_n = 4; + return 1; + } + if (pal_len == 0) + return stbi__err("tRNS before PLTE", "Corrupt PNG"); + if (c.length > pal_len) + return stbi__err("bad tRNS len", "Corrupt PNG"); + pal_img_n = 4; + for (i = 0; i < c.length; ++i) + palette[i * 4 + 3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) + return stbi__err("tRNS with alpha", "Corrupt PNG"); + if (c.length != (stbi__uint32)s->img_n * 2) + return stbi__err("bad tRNS len", "Corrupt PNG"); + has_trans = 1; + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * + stbi__depth_scale_table[z->depth]; // non 8-bit images will + // be larger + } } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) *n = p->s->img_n; - } - STBI_FREE(p->out); p->out = NULL; - STBI_FREE(p->expanded); p->expanded = NULL; - STBI_FREE(p->idata); p->idata = NULL; - - return result; -} - -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi__png p; - p.s = s; - return stbi__do_png(&p, x,y,comp,req_comp, ri); -} - -static int stbi__png_test(stbi__context *s) -{ - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; -} + break; + } -static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) -{ - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind( p->s ); - return 0; - } - if (x) *x = p->s->img_x; - if (y) *y = p->s->img_y; - if (comp) *comp = p->s->img_n; - return 1; -} + case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) + return stbi__err("no PLTE", "Corrupt PNG"); + if (scan == STBI__SCAN_header) { + s->img_n = pal_img_n; + return 1; + } + if ((int)(ioff + c.length) < (int)ioff) + return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) + idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, + idata_limit); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata + ioff, c.length)) + return stbi__err("outofdata", "Corrupt PNG"); + ioff += c.length; + break; + } -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); -} + case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + stbi__uint32 raw_len, bpl; + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) + return 1; + if (z->idata == NULL) + return stbi__err("no IDAT", "Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag( + (char *)z->idata, ioff, raw_len, (int *)&raw_len, !is_iphone); + if (z->expanded == NULL) + return 0; // zlib should set error + STBI_FREE(z->idata); + z->idata = NULL; + if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || + has_trans) + s->img_out_n = s->img_n + 1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, + z->depth, color, interlace)) + return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) + return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) + return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) + s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); + z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } -static int stbi__png_is16(stbi__context *s) -{ - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; + default: + // if critical, fail + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { +#ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); +#endif + return stbi__err(invalid_chunk, + "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, + stbi__result_info *ri) { + void *result = NULL; + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", + "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) + return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) + *n = p->s->img_n; + } + STBI_FREE(p->out); + p->out = NULL; + STBI_FREE(p->expanded); + p->expanded = NULL; + STBI_FREE(p->idata); + p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi__png p; + p.s = s; + return stbi__do_png(&p, x, y, comp, req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) { + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) { + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind(p->s); + return 0; + } + if (x) + *x = p->s->img_x; + if (y) + *y = p->s->img_y; + if (comp) + *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) { + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) { + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context *s) -{ - int r; - int sz; - if (stbi__get8(s) != 'B') return 0; - if (stbi__get8(s) != 'M') return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; -} - -static int stbi__bmp_test(stbi__context *s) -{ - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test_raw(stbi__context *s) { + int r; + int sz; + if (stbi__get8(s) != 'B') + return 0; + if (stbi__get8(s) != 'M') + return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) { + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } - // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) -{ - int n=0; - if (z == 0) return -1; - if (z >= 0x10000) { n += 16; z >>= 16; } - if (z >= 0x00100) { n += 8; z >>= 8; } - if (z >= 0x00010) { n += 4; z >>= 4; } - if (z >= 0x00004) { n += 2; z >>= 2; } - if (z >= 0x00002) { n += 1;/* >>= 1;*/ } - return n; +static int stbi__high_bit(unsigned int z) { + int n = 0; + if (z == 0) + return -1; + if (z >= 0x10000) { + n += 16; + z >>= 16; + } + if (z >= 0x00100) { + n += 8; + z >>= 8; + } + if (z >= 0x00010) { + n += 4; + z >>= 4; + } + if (z >= 0x00004) { + n += 2; + z >>= 2; + } + if (z >= 0x00002) { + n += 1; /* >>= 1;*/ + } + return n; } -static int stbi__bitcount(unsigned int a) -{ - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) { + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) -{ - static unsigned int mul_table[9] = { +static int stbi__shiftsigned(unsigned int v, int shift, int bits) { + static unsigned int mul_table[9] = { 0, - 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, - 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0,0,1,0,2,4,6,0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8-bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; -} - -typedef struct -{ - int bpp, offset, hsz; - unsigned int mr,mg,mb,ma, all_a; - int extra_read; + 0xff /*0b11111111*/, + 0x55 /*0b01010101*/, + 0x49 /*0b01001001*/, + 0x11 /*0b00010001*/, + 0x21 /*0b00100001*/, + 0x41 /*0b01000001*/, + 0x81 /*0b10000001*/, + 0x01 /*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0, 0, 1, 0, 2, 4, 6, 0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8 - bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct { + int bpp, offset, hsz; + unsigned int mr, mg, mb, ma, all_a; + int extra_read; } stbi__bmp_data; -static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) -{ - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; - - if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); - - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - int i; - if (hsz != 108 && hsz != 124) +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) { + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') + return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) + return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) + return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) + return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) + return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha + // channel but it was all 0 + } else { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - stbi__get32le(s); // discard color space - for (i=0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); } - } - return (void *) 1; -} - - -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - unsigned int mr=0,mg=0,mb=0,ma=0, all_a; - stbi_uc pal[256][4]; - int psize=0,i,j,width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set - - flip_vertically = ((int) s->img_y) > 0; - s->img_y = abs((int) s->img_y); - - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; - - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - STBI_ASSERT(info.offset == s->callback_already_read + (int) (s->img_buffer - s->img_buffer_original)); - if (info.offset != s->callback_already_read + (s->img_buffer - s->buffer_start)) { - return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + stbi__get32le(s); // discard color space + for (i = 0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved } - } - - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert - - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); - - out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z=0; - if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } - for (i=0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) stbi__get8(s); - pal[i][3] = 255; + } + } + return (void *)1; +} + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; + stbi_uc pal[256][4]; + int psize = 0, i, j, width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int)s->img_y) > 0; + s->img_y = abs((int)s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + STBI_ASSERT(info.offset == + s->callback_already_read + + (int)(s->img_buffer - s->img_buffer_original)); + if (info.offset != + s->callback_already_read + (s->img_buffer - s->buffer_start)) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z = 0; + if (psize == 0 || psize > 256) { + STBI_FREE(out); + return stbi__errpuc("invalid", "Corrupt BMP"); + } + for (i = 0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) + stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - + psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) + width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) + width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) + width = s->img_x; + else { + STBI_FREE(out); + return stbi__errpuc("bad bpp", "Corrupt BMP"); + } + pad = (-width) & 3; + if (info.bpp == 1) { + for (j = 0; j < (int)s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i = 0; i < (int)s->img_x; ++i) { + int color = (v >> bit_offset) & 0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + if ((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) width = s->img_x; - else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } - pad = (-width)&3; - if (info.bpp == 1) { - for (j=0; j < (int) s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i=0; i < (int) s->img_x; ++i) { - int color = (v>>bit_offset)&0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - if((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j=0; j < (int) s->img_y; ++j) { - for (i=0; i < (int) s->img_x; i += 2) { - int v=stbi__get8(s),v2=0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - } - stbi__skip(s, pad); - } + } else { + for (j = 0; j < (int)s->img_y; ++j) { + for (i = 0; i < (int)s->img_x; i += 2) { + int v = stbi__get8(s), v2 = 0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + } + stbi__skip(s, pad); } - } else { - int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; - int z = 0; - int easy=0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) width = 3 * s->img_x; - else if (info.bpp == 16) width = 2*s->img_x; - else /* bpp = 32 and pad = 0 */ width=0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; + } + } else { + int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, + bcount = 0, acount = 0; + int z = 0; + int easy = 0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) + width = 3 * s->img_x; + else if (info.bpp == 16) + width = 2 * s->img_x; + else /* bpp = 32 and pad = 0 */ + width = 0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - if (!easy) { - if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr) - 7; + rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg) - 7; + gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb) - 7; + bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma) - 7; + acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - for (j=0; j < (int) s->img_y; ++j) { - if (easy) { - for (i=0; i < (int) s->img_x; ++i) { - unsigned char a; - out[z+2] = stbi__get8(s); - out[z+1] = stbi__get8(s); - out[z+0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i=0; i < (int) s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) out[z++] = STBI__BYTECAST(a); - } - } - stbi__skip(s, pad); + } + for (j = 0; j < (int)s->img_y; ++j) { + if (easy) { + for (i = 0; i < (int)s->img_x; ++i) { + unsigned char a; + out[z + 2] = stbi__get8(s); + out[z + 1] = stbi__get8(s); + out[z + 0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) + out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i = 0; i < (int)s->img_x; ++i) { + stbi__uint32 v = + (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) + out[z++] = STBI__BYTECAST(a); + } } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j=0; j < (int) s->img_y>>1; ++j) { - stbi_uc *p1 = out + j *s->img_x*target; - stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; - for (i=0; i < (int) s->img_x*target; ++i) { - t = p1[i]; p1[i] = p2[i]; p2[i] = t; - } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j = 0; j < (int)s->img_y >> 1; ++j) { + stbi_uc *p1 = out + j * s->img_x * target; + stbi_uc *p2 = out + (s->img_y - 1 - j) * s->img_x * target; + for (i = 0; i < (int)s->img_x * target; ++i) { + t = p1[i]; + p1[i] = p2[i]; + p2[i] = t; } - } + } + } - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; - return out; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + return out; } #endif @@ -5555,592 +6271,625 @@ static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) -{ - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) *is_rgb16 = 0; - switch(bits_per_pixel) { - case 8: return STBI_grey; - case 16: if(is_grey) return STBI_grey_alpha; - // fallthrough - case 15: if(is_rgb16) *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: return bits_per_pixel/8; - default: return 0; - } -} - -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) -{ - int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; - int sz, tga_colormap_type; - stbi__get8(s); // discard Offset - tga_colormap_type = stbi__get8(s); // colormap type - if( tga_colormap_type > 1 ) { - stbi__rewind(s); - return 0; // only RGB or indexed allowed - } - tga_image_type = stbi__get8(s); // image type - if ( tga_colormap_type == 1 ) { // colormapped (paletted) image - if (tga_image_type != 1 && tga_image_type != 9) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip image x and y origin - tga_colormap_bpp = sz; - } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { - stbi__rewind(s); - return 0; // only RGB or grey allowed, +/- RLE - } - stbi__skip(s,9); // skip colormap specification and image x/y origin - tga_colormap_bpp = 0; - } - tga_w = stbi__get16le(s); - if( tga_w < 1 ) { - stbi__rewind(s); - return 0; // test width +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int *is_rgb16) { + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) + *is_rgb16 = 0; + switch (bits_per_pixel) { + case 8: + return STBI_grey; + case 16: + if (is_grey) + return STBI_grey_alpha; + // fallthrough + case 15: + if (is_rgb16) + *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: + return bits_per_pixel / 8; + default: + return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) { + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, + tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if (tga_colormap_type > 1) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; } - tga_h = stbi__get16le(s); - if( tga_h < 1 ) { - stbi__rewind(s); - return 0; // test height + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__rewind(s); + return 0; } - tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits - if (tga_colormap_bpp != 0) { - if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { - // when using a colormap, tga_bits_per_pixel is the size of the indexes - // I don't think anything but 8 or 16bit indexes makes sense - stbi__rewind(s); - return 0; - } - tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); - } else { - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + stbi__skip(s, 4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ((tga_image_type != 2) && (tga_image_type != 3) && + (tga_image_type != 10) && (tga_image_type != 11)) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE } - if(!tga_comp) { + stbi__skip(s, 9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if (tga_w < 1) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if (tga_h < 1) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); return 0; } - if (x) *x = tga_w; - if (y) *y = tga_h; - if (comp) *comp = tga_comp; - return 1; // seems to have passed everything -} - -static int stbi__tga_test(stbi__context *s) -{ - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if ( tga_color_type == 1 ) { // colormapped (paletted) image - if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - stbi__skip(s,4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s,9); // skip colormap specification and image x/y origin - } - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp( + tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), + NULL); + } + if (!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) + *x = tga_w; + if (y) + *y = tga_h; + if (comp) + *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) { + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if (tga_color_type > 1) + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image + if (sz != 1 && sz != 9) + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin + } + if (stbi__get16le(s) < 1) + goto errorEnd; // test width + if (stbi__get16le(s) < 1) + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) + goto errorEnd; // for colormapped images, bpp is size of an index + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead + // of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) -{ - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255)/31); - out[1] = (stbi_uc)((g * 255)/31); - out[2] = (stbi_uc)((b * 255)/31); - - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. -} - -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16=0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char *tga_data; - unsigned char *tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO - - if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // do a tiny bit of precessing - if ( tga_image_type >= 8 ) - { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); - - // If I'm paletted, then I'll use the number of bits from the palette - if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - - if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - - // tga info - *x = tga_width; - *y = tga_height; - if (comp) *comp = tga_comp; - - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); - - tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); - - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset ); - - if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { - for (i=0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height -i - 1 : i; - stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if ( tga_indexed) - { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start ); - // load the palette - tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); - } - if (tga_rgb16) { - stbi_uc *pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i=0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc *out) { + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be + // swapped later + out[0] = (stbi_uc)((r * 255) / 31); + out[1] = (stbi_uc)((g * 255) / 31); + out[2] = (stbi_uc)((b * 255) / 31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16 = 0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused + // (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // do a tiny bit of precessing + if (tga_image_type >= 8) { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if (tga_indexed) + tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), + &tga_rgb16); + + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have + // ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) + *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = + (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) + return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset); + + if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { + for (i = 0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height - i - 1 : i; + stbi_uc *tga_row = tga_data + row * tga_width * tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if (tga_indexed) { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // load the data - for (i=0; i < tga_width * tga_height; ++i) - { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if ( tga_is_RLE ) - { - if ( RLE_count == 0 ) - { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if ( !RLE_repeating ) - { - read_next_pixel = 1; - } - } else - { - read_next_pixel = 1; - } - // OK, if I need to read a pixel, do it now - if ( read_next_pixel ) - { - // load however much data we did have - if ( tga_indexed ) - { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if ( pal_idx >= tga_palette_len ) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx+j]; - } - } else if(tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel - - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i*tga_comp+j] = raw_data[j]; - // in case we're in RLE mode, keep counting down - --RLE_count; + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start); + // load the palette + tga_palette = + (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); } - // do I need to invert the image? - if ( tga_inverted ) - { - for (j = 0; j*2 < tga_height; ++j) - { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) - { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } - } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i = 0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // clear my palette, if I had one - if ( tga_palette != NULL ) - { - STBI_FREE( tga_palette ); + } + // load the data + for (i = 0; i < tga_width * tga_height; ++i) { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if (tga_is_RLE) { + if (RLE_count == 0) { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if (!RLE_repeating) { + read_next_pixel = 1; + } + } else { + read_next_pixel = 1; } - } + // OK, if I need to read a pixel, do it now + if (read_next_pixel) { + // load however much data we did have + if (tga_indexed) { + // read in index, then perform the lookup + int pal_idx = + (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if (pal_idx >= tga_palette_len) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx + j]; + } + } else if (tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) - { - unsigned char* tga_pixel = tga_data; - for (i=0; i < tga_width * tga_height; ++i) - { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i * tga_comp + j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if (tga_inverted) { + for (j = 0; j * 2 < tga_height; ++j) { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } } - } + } + // clear my palette, if I had one + if (tga_palette != NULL) { + STBI_FREE(tga_palette); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) { + unsigned char *tga_pixel = tga_data; + for (i = 0; i < tga_width * tga_height; ++i) { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, + tga_height); - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = - tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = + tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; } #endif // ************************************************************************************************* -// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, +// tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s) -{ - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; -} - -static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) -{ - int count, nleft, len; - - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; - } +static int stbi__psd_test(stbi__context *s) { + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) { + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) + return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; } - } - - return 1; -} + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) + return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w, h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", + "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", + "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for + // other modes.) + stbi__skip(s, stbi__get32be(s)); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s)); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s)); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", + "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *)stbi__malloc(4 * w * h); + + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w * h; + + // Initialize the data to zero. + // memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes + // literally. Else if n is between -127 and -1 inclusive, copy the next + // byte -n+1 times. Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row + // in the data, which we're going to just skip. + stbi__skip(s, h * channelCount * 2); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out + channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w,h; - stbi_uc *out; - STBI_NOTUSED(ri); - - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); - - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - - // Skip 6 reserved bytes. - stbi__skip(s, 6 ); - - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); - - if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s,stbi__get32be(s) ); - - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s) ); - - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s) ); - - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); - - // Create the destination image. - - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *) stbi__malloc(4 * w*h); - - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w*h; - - // Initialize the data to zero. - //memset( out, 0, pixelCount * 4 ); - - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop - - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2 ); - - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc *p; - - p = out+channel; - if (channel >= channelCount) { - // Fill this channel with default data. + } else { + // We're at the raw image data. It's each channel in order (Red, Green, + // Blue, Alpha, ...) where each channel consists of an 8-bit (or 16-bit) + // value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out + channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16)stbi__get16be(s); + } else { + stbi_uc *p = out + channel; + if (bitdepth == 16) { // input bpc for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); - } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } - } + *p = (stbi_uc)(stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } } - - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc *p = out+channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } - } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16) stbi__get16be(s); - } else { - stbi_uc *p = out+channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc) (stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } - } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i = 0; i < w * h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *)out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); + pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); + pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); + } } - } - - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i=0; i < w*h; ++i) { - stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); - pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); - pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); - } - } - } else { - for (i=0; i < w*h; ++i) { - unsigned char *pixel = out + 4*i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); - pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); - pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); - } - } + } else { + for (i = 0; i < w * h; ++i) { + unsigned char *pixel = out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); + pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); + pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); + } } - } - - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - - if (comp) *comp = 4; - *y = h; - *x = w; - - return out; + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, + w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + if (comp) + *comp = 4; + *y = h; + *x = w; + + return out; } #endif @@ -6152,215 +6901,222 @@ static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context *s,const char *str) -{ - int i; - for (i=0; i<4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s, const char *str) { + int i; + for (i = 0; i < 4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context *s) -{ - int i; +static int stbi__pic_test_core(stbi__context *s) { + int i; - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) + return 0; - for(i=0;i<84;++i) - stbi__get8(s); + for (i = 0; i < 84; ++i) + stbi__get8(s); - if (!stbi__pic_is4(s,"PICT")) - return 0; + if (!stbi__pic_is4(s, "PICT")) + return 0; - return 1; + return 1; } -typedef struct -{ - stbi_uc size,type,channel; +typedef struct { + stbi_uc size, type, channel; } stbi__pic_packet; -static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) -{ - int mask=0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) { + int mask = 0x80, i; - for (i=0; i<4; ++i, mask>>=1) { - if (channel & mask) { - if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); - dest[i]=stbi__get8(s); - } - } + for (i = 0; i < 4; ++i, mask >>= 1) { + if (channel & mask) { + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "PIC file too short"); + dest[i] = stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) -{ - int mask=0x80,i; +static void stbi__copyval(int channel, stbi_uc *dest, const stbi_uc *src) { + int mask = 0x80, i; - for (i=0;i<4; ++i, mask>>=1) - if (channel&mask) - dest[i]=src[i]; + for (i = 0; i < 4; ++i, mask >>= 1) + if (channel & mask) + dest[i] = src[i]; } -static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) -{ - int act_comp=0,num_packets=0,y,chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s, int width, int height, + int *comp, stbi_uc *result) { + int act_comp = 0, num_packets = 0, y, chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data - // for the same channel in multiple packets. - do { - stbi__pic_packet *packet; + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. + do { + stbi__pic_packet *packet; - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return stbi__errpuc("bad format","too many packets"); + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return stbi__errpuc("bad format", "too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); - if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (reading packets)"); + if (packet->size != 8) + return stbi__errpuc("bad format", "packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for(y=0; ytype) { - default: - return stbi__errpuc("bad format","packet has bad compression type"); + switch (packet->type) { + default: + return stbi__errpuc("bad format", "packet has bad compression type"); - case 0: {//uncompressed - int x; + case 0: { // uncompressed + int x; - for(x=0;xchannel,dest)) - return 0; - break; - } + for (x = 0; x < width; ++x, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + break; + } - case 1://Pure RLE - { - int left=width, i; - - while (left>0) { - stbi_uc count,value[4]; - - count=stbi__get8(s); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - - if (count > left) - count = (stbi_uc) left; - - if (!stbi__readval(s,packet->channel,value)) return 0; - - for(i=0; ichannel,dest,value); - left -= count; - } - } - break; - - case 2: {//Mixed RLE - int left=width; - while (left>0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); - - if (count >= 128) { // Repeated - stbi_uc value[4]; - - if (count==128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file","scanline overrun"); - - if (!stbi__readval(s,packet->channel,value)) - return 0; - - for(i=0;ichannel,dest,value); - } else { // Raw - ++count; - if (count>left) return stbi__errpuc("bad file","scanline overrun"); - - for(i=0;ichannel,dest)) - return 0; - } - left-=count; - } - break; - } - } + case 1: // Pure RLE + { + int left = width, i; + + while (left > 0) { + stbi_uc count, value[4]; + + count = stbi__get8(s); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pure read count)"); + + if (count > left) + count = (stbi_uc)left; + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + left -= count; + } + } break; + + case 2: { // Mixed RLE + int left = width; + while (left > 0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", + "file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count == 128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + } else { // Raw + ++count; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + for (i = 0; i < count; ++i, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + } + left -= count; + } + break; + } } - } + } + } - return result; + return result; } -static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) -{ - stbi_uc *result; - int i, x,y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s, int *px, int *py, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *result; + int i, x, y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) comp = &internal_comp; + if (!comp) + comp = &internal_comp; - for (i=0; i<92; ++i) - stbi__get8(s); + for (i = 0; i < 92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) + return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); //skip `ratio' - stbi__get16be(s); //skip `fields' - stbi__get16be(s); //skip `pad' + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); - memset(result, 0xff, x*y*4); + // intermediate buffer is RGBA + result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); + memset(result, 0xff, x * y * 4); - if (!stbi__pic_load_core(s,x,y,comp, result)) { - STBI_FREE(result); - result=0; - } - *px = x; - *py = y; - if (req_comp == 0) req_comp = *comp; - result=stbi__convert_format(result,4,req_comp,x,y); + if (!stbi__pic_load_core(s, x, y, comp, result)) { + STBI_FREE(result); + result = 0; + } + *px = x; + *py = y; + if (req_comp == 0) + req_comp = *comp; + result = stbi__convert_format(result, 4, req_comp, x, y); - return result; + return result; } -static int stbi__pic_test(stbi__context *s) -{ - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) { + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6368,514 +7124,539 @@ static int stbi__pic_test(stbi__context *s) // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct -{ - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct { + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct -{ - int w,h; - stbi_uc *out; // output buffer (always 4 components) - stbi_uc *background; // The current "background" as far as a gif is concerned - stbi_uc *history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc *color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct { + int w, h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context *s) -{ - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') return 0; - if (stbi__get8(s) != 'a') return 0; - return 1; -} - -static int stbi__gif_test(stbi__context *s) -{ - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; -} - -static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) -{ - int i; - for (i=0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } -} +static int stbi__gif_test_raw(stbi__context *s) { + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') + return 0; + if (stbi__get8(s) != 'a') + return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) { + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], + int num_entries, int transp) { + int i; + for (i = 0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, + int is_info) { + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') + return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') + return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + + if (comp != 0) + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the + // comments + + if (is_info) + return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) { + stbi__gif *g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind(s); + return 0; + } + if (x) + *x = g->w; + if (y) + *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) { + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) + return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) { + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) + return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc)init_code; + g->codes[init_code].suffix = (stbi_uc)init_code; + } + + // support no starting clear code + avail = clear + 2; + oldcode = -1; + + len = 0; + for (;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32)stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s, len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } -static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) -{ - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - version = stbi__get8(s); - if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + p->prefix = (stbi__int16)oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; + stbi__out_gif_code(g, (stbi__uint16)code); - if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } - if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image +// doesn't support it two back is the image from two frames ago, used for a very +// specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, + int req_comp, stbi_uc *two_back) { + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour + // (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp, 0)) + return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *)stbi__malloc(4 * pcount); + g->background = (stbi_uc *)stbi__malloc(4 * pcount); + g->history = (stbi_uc *)stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); - if (is_info) return 1; + // image is treated as "transparent" at the start - ie, nothing overwrites + // the current background; background colour is only used for pixels that + // are not rendered first frame, after that "background" color refers to the + // color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, + 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, + pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the + // old background + } - if (g->flags & 0x80) - stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } - return 1; -} + // background is what out is after the undoing of the previou frame; + memcpy(g->background, g->out, 4 * g->w * g->h); + } + + // clear my history; + memset(g->history, 0x00, + g->w * g->h); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc *o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; -static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind( s ); - return 0; - } - if (x) *x = g->w; - if (y) *y = g->h; - STBI_FREE(g); - return 1; -} + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; -static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) -{ - stbi_uc *p, *c; - int idx; - - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); - - if (g->cur_y >= g->max_y) return; - - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; - - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; - - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + g->lflags = stbi__get8(s); - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; } - } -} -static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) -{ - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw *p; - - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc) init_code; - g->codes[init_code].suffix = (stbi_uc) init_code; - } - - // support no starting clear code - avail = clear+2; - oldcode = -1; - - len = 0; - for(;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; - } - --len; - bits |= (stbi__int32) stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s,len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), + g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *)g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *)g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } + o = stbi__process_gif_raster(s, g); + if (!o) + return NULL; - p->prefix = (stbi__int16) oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = + 255; // just in case it was made transparent, undo that; It will + // be reset next frame if need be; + memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); + } + } + } - stbi__out_gif_code(g, (stbi__uint16) code); + return o; + } - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = + 10 * stbi__get16le( + s); // delay - 1/100th of a second, saving as 1/1000ths. + + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } - } -} - -// this function is designed to support animated gifs, although stb_image doesn't support it -// two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) -{ - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); - - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *) stbi__malloc(4 * pcount); - g->background = (stbi_uc *) stbi__malloc(4 * pcount); - g->history = (stbi_uc *) stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); - - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; - - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); } + break; + } - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); - } - } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); - } - } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers - // background is what out is after the undoing of the previou frame; - memcpy( g->background, g->out, 4 * g->w * g->h ); - } - - // clear my history; - memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame - - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { - stbi__int32 x, y, w, h; - stbi_uc *o; - - x = stbi__get16le(s); - y = stbi__get16le(s); - w = stbi__get16le(s); - h = stbi__get16le(s); - if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); - - g->line_size = g->w * 4; - g->start_x = x * 4; - g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; - - // if the width of the specified rectangle is 0, that means - // we may not see *any* pixels or the image is malformed; - // to make sure this is caught, move the current y down to - // max_y (which is what out_gif_code checks). - if (w == 0) - g->cur_y = g->max_y; - - g->lflags = stbi__get8(s); - - if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; - } else { - g->step = g->line_size; - g->parse = 0; - } + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp) { + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } - if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *) g->lpal; - } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *) g->pal; - } else - return stbi__errpuc("missing color table", "Corrupt GIF"); - - o = stbi__process_gif_raster(s, g); - if (!o) return NULL; - - // if this was the first frame, - pcount = g->w * g->h; - if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); - } - } - } + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = + (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + if (NULL == tmp) { + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + return stbi__errpuc("outofmem", "Out of memory"); + } else { + out = (stbi_uc *)tmp; + out_size = layers * stride; + } + + if (delays) { + *delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, + sizeof(int) * layers); + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc *)stbi__malloc(layers * stride); + out_size = layers * stride; + if (delays) { + *delays = (int *)stbi__malloc(layers * sizeof(int)); + delays_size = layers * sizeof(int); + } + } + memcpy(out + ((layers - 1) * stride), u, stride); + if (layers >= 2) { + two_back = out - 2 * stride; + } - return o; - } - - case 0x21: // Comment Extension. - { - int len; - int ext = stbi__get8(s); - if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } - } - while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); - } - break; - } + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); - case 0x3B: // gif stream termination code - return (stbi_uc *) s; // using '1' causes warning on some compilers + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - default: - return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } -} + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc *u = 0; - stbi_uc *out = 0; - stbi_uc *two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; - - if (out) { - void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); - if (NULL == tmp) { - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); - return stbi__errpuc("outofmem", "Out of memory"); - } - else { - out = (stbi_uc*) tmp; - out_size = layers * stride; - } - - if (delays) { - *delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc*)stbi__malloc( layers * stride ); - out_size = layers * stride; - if (delays) { - *delays = (int*) stbi__malloc( layers * sizeof(int) ); - delays_size = layers * sizeof(int); - } - } - memcpy( out + ((layers - 1) * stride), u, stride ); - if (layers >= 2) { - two_back = out - 2 * stride; - } +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); - if (delays) { - (*delays)[layers - 1U] = g.delay; - } - } - } while (u != 0); + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + return u; } -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); - - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } - - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); - - return u; -} - -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) -{ - return stbi__gif_info_raw(s,x,y,comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) { + return stbi__gif_info_raw(s, x, y, comp); } #endif @@ -6883,396 +7664,434 @@ static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context *s, const char *signature) -{ - int i; - for (i=0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; -} - -static int stbi__hdr_test(stbi__context* s) -{ - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if(!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; -} - -#define STBI__HDR_BUFLEN 1024 -static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) -{ - int len=0; - char c = '\0'; - - c = (char) stbi__get8(z); - - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN-1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) { + int i; + for (i = 0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context *s) { + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if (!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) { + int len = 0; + char c = '\0'; + + c = (char)stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN - 1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char)stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) { + if (input[3] != 0) { + float f1; + // Exponent + f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) + output[1] = 1; + if (req_comp == 4) + output[3] = 1; + } else { + switch (req_comp) { + case 4: + output[3] = 1; /* fallthrough */ + case 3: + output[0] = output[1] = output[2] = 0; + break; + case 2: + output[1] = 1; /* fallthrough */ + case 1: + output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1, c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s, buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && + strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) + return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int)strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) + *comp = 3; + if (req_comp == 0) + req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = + (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if (width < 8 || width >= 32768) { + // Read flat data + for (j = 0; j < height; ++j) { + for (i = 0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, + req_comp); } - c = (char) stbi__get8(z); - } - - buffer[len] = 0; - return buffer; -} + } + } else { + // Read RLE-encoded data + scanline = NULL; -static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) -{ - if ( input[3] != 0 ) { - float f1; - // Exponent - f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a + // decoded pixel (note this can't be a valid pixel--one of RGB must be + // >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc)c1; + rgbe[1] = (stbi_uc)c2; + rgbe[2] = (stbi_uc)len; + rgbe[3] = (stbi_uc)stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense } - if (req_comp == 2) output[1] = 1; - if (req_comp == 4) output[3] = 1; - } else { - switch (req_comp) { - case 4: output[3] = 1; /* fallthrough */ - case 3: output[0] = output[1] = output[2] = 0; - break; - case 2: output[1] = 1; /* fallthrough */ - case 1: output[0] = 0; - break; + len <<= 8; + len |= stbi__get8(s); + if (len != width) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - } -} - -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int width, height; - stbi_uc *scanline; - float *hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1,c2, z; - const char *headerToken; - STBI_NOTUSED(ri); - - // Check identifier - headerToken = stbi__hdr_gettoken(s,buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); - - // Parse header - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); - - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int) strtol(token, NULL, 10); - - if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - - *x = width; - *y = height; - - if (comp) *comp = 3; - if (req_comp == 0) req_comp = 3; - - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); - - // Read data - hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); - - // Load image data - // image data is stored as some number of sca - if ( width < 8 || width >= 32768) { - // Read flat data - for (j=0; j < height; ++j) { - for (i=0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } + if (scanline == NULL) { + scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } } - } else { - // Read RLE-encoded data - scanline = NULL; - - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc) c1; - rgbe[1] = (stbi_uc) c2; - rgbe[2] = (stbi_uc) len; - rgbe[3] = (stbi_uc) stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - if (scanline == NULL) { - scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - for (i=0; i < width; ++i) - stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } } - if (scanline) - STBI_FREE(scanline); - } - - return hdr_data; -} - -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int dummy; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (stbi__hdr_test(s) == 0) { - stbi__rewind( s ); - return 0; - } - - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) { - stbi__rewind( s ); - return 0; - } - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *y = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *x = (int) strtol(token, NULL, 10); - *comp = 3; - return 1; + for (i = 0; i < width; ++i) + stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, + scanline + i * 4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind(s); + return 0; + } + + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) { + stbi__rewind(s); + return 0; + } + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *y = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *x = (int)strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) -{ - void *p; - stbi__bmp_data info; - - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - stbi__rewind( s ); - if (p == NULL) - return 0; - if (x) *x = s->img_x; - if (y) *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) { + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + stbi__rewind(s); + if (p == NULL) + return 0; + if (x) + *x = s->img_x; + if (y) + *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) -{ - int channelCount, dummy, depth; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind( s ); - return 0; - } - *comp = 4; - return 1; -} - -static int stbi__psd_is16(stbi__context *s) -{ - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - (void) stbi__get32be(s); - (void) stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) { + int channelCount, dummy, depth; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind(s); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) { + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + (void)stbi__get32be(s); + (void)stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind(s); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) -{ - int act_comp=0,num_packets=0,chained,dummy; - stbi__pic_packet packets[10]; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { - stbi__rewind(s); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) { + int act_comp = 0, num_packets = 0, chained, dummy; + stbi__pic_packet packets[10]; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) return 0; - } - stbi__skip(s, 88); + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind( s); + if (stbi__at_eof(s)) { + stbi__rewind(s); return 0; - } - if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind( s ); + } + if (packet->size != 8) { + stbi__rewind(s); return 0; - } - - stbi__skip(s, 8); - - do { - stbi__pic_packet *packet; - - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return 0; - - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; - - if (stbi__at_eof(s)) { - stbi__rewind( s ); - return 0; - } - if (packet->size != 8) { - stbi__rewind( s ); - return 0; - } - } while (chained); + } + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); + *comp = (act_comp & 0x10 ? 4 : 3); - return 1; + return 1; } #endif @@ -7290,257 +8109,266 @@ static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s) -{ - char p, t; - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) { + char p, t; + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + return 1; } -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - STBI_NOTUSED(ri); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + STBI_NOTUSED(ri); - if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) - return 0; + if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) + return 0; - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; - if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "PNM too large"); + if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "PNM too large"); - out = (stbi_uc *) stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - stbi__getn(s, out, s->img_n * s->img_x * s->img_y); + out = (stbi_uc *)stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + stbi__getn(s, out, s->img_n * s->img_x * s->img_y); - if (req_comp && req_comp != s->img_n) { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - return out; + if (req_comp && req_comp != s->img_n) { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + return out; } -static int stbi__pnm_isspace(char c) -{ - return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +static int stbi__pnm_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; } -static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) -{ - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char) stbi__get8(s); +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) { + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char)stbi__get8(s); - if (stbi__at_eof(s) || *c != '#') - break; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') + *c = (char)stbi__get8(s); + } } -static int stbi__pnm_isdigit(char c) -{ - return c >= '0' && c <= '9'; +static int stbi__pnm_isdigit(char c) { + return c >= '0' && c <= '9'; } -static int stbi__pnm_getinteger(stbi__context *s, char *c) -{ - int value = 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) { + int value = 0; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value*10 + (*c - '0'); - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value * 10 + (*c - '0'); + *c = (char)stbi__get8(s); + } - return value; + return value; } -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) -{ - int maxv, dummy; - char c, p, t; +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) { + int maxv, dummy; + char c, p, t; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; - stbi__rewind(s); + stbi__rewind(s); - // Get identifier - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } + // Get identifier + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = + (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - c = (char) stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); + c = (char)stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width - stbi__pnm_skip_whitespace(s, &c); + *x = stbi__pnm_getinteger(s, &c); // read width + stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height - stbi__pnm_skip_whitespace(s, &c); + *y = stbi__pnm_getinteger(s, &c); // read height + stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value + maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 255) - return stbi__err("max value > 255", "PPM image not 8-bit"); - else - return 1; + if (maxv > 255) + return stbi__err("max value > 255", "PPM image not 8-bit"); + else + return 1; } #endif -static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) -{ - #ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) return 1; - #endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) { +#ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) + return 1; +#endif - // test tga last because it's a crappy test! - #ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; - #endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +// test tga last because it's a crappy test! +#ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; +#endif + return stbi__err("unknown image type", + "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context *s) -{ - #ifndef STBI_NO_PNG - if (stbi__png_is16(s)) return 1; - #endif +static int stbi__is_16_main(stbi__context *s) { +#ifndef STBI_NO_PNG + if (stbi__png_is16(s)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) + return 1; +#endif - return 0; + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_info_from_file(f, x, y, comp); - fclose(f); - return result; -} - -STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s,x,y,comp); - fseek(f,pos,SEEK_SET); - return r; -} - -STBIDEF int stbi_is_16_bit(char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_is_16_bit_from_file(f); - fclose(f); - return result; -} - -STBIDEF int stbi_is_16_bit_from_file(FILE *f) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f,pos,SEEK_SET); - return r; +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s, x, y, comp); + fseek(f, pos, SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE *f) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f, pos, SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, + int *x, int *y, int *comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, + void *user) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION /* revision history: - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs - 2.19 (2018-02-11) fix warning - 2.18 (2018-01-30) fix warnings - 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and + platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix + warnings 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug 1-bit BMP *_is_16_bit api avoid warnings @@ -7555,13 +8383,11 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user warning fixes; disable run-time SSE detection on gcc; uniform handling of optional "return" values; thread-safe initialization of zlib tables - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) allocate large structures on the stack - remove white matting for transparent PSD - fix reported channel count for PNG & BMP - re-enable SSE2 in non-gcc 64-bit + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet + JPGs 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now 2.12 + (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 (2016-04-02) + allocate large structures on the stack remove white matting for transparent + PSD fix reported channel count for PNG & BMP re-enable SSE2 in non-gcc 64-bit support RGB-formatted JPEG read 16-bit PNGs (only as 8-bit) 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED @@ -7569,11 +8395,9 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 16-bit-per-pixel TGA (not bit-per-component) info() for TGA could break due to .hdr handling info() for BMP to shares code instead of sloppy parse - can use STBI_REALLOC_SIZED if allocator doesn't support realloc - code cleanup - 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA - 2.07 (2015-09-13) fix compiler warnings - partial animated GIF support + can use STBI_REALLOC_SIZED if allocator doesn't support + realloc code cleanup 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD + as RGBA 2.07 (2015-09-13) fix compiler warnings partial animated GIF support limited 16-bpc PSD support #ifdef unused functions bug with < 92 byte PIC,PNM,HDR,TGA @@ -7584,23 +8408,18 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user stbi_set_flip_vertically_on_load (nguillemot) fix NEON support; fix mingw support 2.02 (2015-01-19) fix incorrect assert, fix warning - 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 - 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG - 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) - progressive JPEG (stb) - PGM/PPM support (Ken Miller) - STBI_MALLOC,STBI_REALLOC,STBI_FREE + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit + without -msse2 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG 2.00 + (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) progressive + JPEG (stb) PGM/PPM support (Ken Miller) STBI_MALLOC,STBI_REALLOC,STBI_FREE GIF bugfix -- seemingly never worked STBI_NO_*, STBI_ONLY_* 1.48 (2014-12-14) fix incorrectly-named assert() - 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) - optimize PNG (ryg) - fix bug in interlaced PNG with user-specified channel count (stb) - 1.46 (2014-08-26) - fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG - 1.45 (2014-08-16) - fix MSVC-ARM internal compiler error by wrapping malloc - 1.44 (2014-08-07) + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar + Cornut & stb) optimize PNG (ryg) fix bug in interlaced PNG with + user-specified channel count (stb) 1.46 (2014-08-26) fix broken tRNS chunk + (colorkey-style transparency) in non-paletted PNG 1.45 (2014-08-16) fix + MSVC-ARM internal compiler error by wrapping malloc 1.44 (2014-08-07) various warning fixes from Ronny Chevalier 1.43 (2014-07-15) fix MSVC-only compiler problem in code changed in 1.42 @@ -7609,73 +8428,48 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user fixes to stbi__cleanup_jpeg path added STBI_ASSERT to avoid requiring assert.h 1.41 (2014-06-25) - fix search&replace from 1.36 that messed up comments/error messages - 1.40 (2014-06-22) - fix gcc struct-initialization warning - 1.39 (2014-06-15) - fix to TGA optimization when req_comp != number of components in TGA; - fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) - add support for BMP version 5 (more ignored fields) - 1.38 (2014-06-06) - suppress MSVC warnings on integer casts truncating values - fix accidental rename of 'skip' field of I/O - 1.37 (2014-06-04) - remove duplicate typedef - 1.36 (2014-06-03) - convert to header file single-file library - if de-iphone isn't set, load iphone images color-swapped instead of returning NULL - 1.35 (2014-05-27) - various warnings - fix broken STBI_SIMD path - fix bug where stbi_load_from_file no longer left file pointer in correct place - fix broken non-easy path for 32-bit BMP (possibly never used) - TGA optimization by Arseny Kapoulkine - 1.34 (unknown) - use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case - 1.33 (2011-07-14) - make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements - 1.32 (2011-07-13) - support for "info" function for all supported filetypes (SpartanJ) - 1.31 (2011-06-20) - a few more leak fixes, bug in PNG handling (SpartanJ) - 1.30 (2011-06-11) - added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + fix search&replace from 1.36 that messed up comments/error + messages 1.40 (2014-06-22) fix gcc struct-initialization warning 1.39 + (2014-06-15) fix to TGA optimization when req_comp != number of components in + TGA; fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my + test suite) add support for BMP version 5 (more ignored fields) 1.38 + (2014-06-06) suppress MSVC warnings on integer casts truncating values fix + accidental rename of 'skip' field of I/O 1.37 (2014-06-04) remove duplicate + typedef 1.36 (2014-06-03) convert to header file single-file library if + de-iphone isn't set, load iphone images color-swapped instead of returning + NULL 1.35 (2014-05-27) various warnings fix broken STBI_SIMD path fix bug + where stbi_load_from_file no longer left file pointer in correct place fix + broken non-easy path for 32-bit BMP (possibly never used) TGA optimization by + Arseny Kapoulkine 1.34 (unknown) use STBI_NOTUSED in + stbi__resample_row_generic(), fix one more leak in tga failure case 1.33 + (2011-07-14) make stbi_is_hdr work in STBI_NO_HDR (as specified), minor + compiler-friendly improvements 1.32 (2011-07-13) support for "info" function + for all supported filetypes (SpartanJ) 1.31 (2011-06-20) a few more leak + fixes, bug in PNG handling (SpartanJ) 1.30 (2011-06-11) added ability to + load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway - error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) - fix inefficiency in decoding 32-bit BMP (David Woo) - 1.29 (2010-08-16) - various warning fixes from Aurelien Pocheville - 1.28 (2010-08-01) - fix bug in GIF palette transparency (SpartanJ) - 1.27 (2010-08-01) - cast-to-stbi_uc to fix warnings - 1.26 (2010-07-24) - fix bug in file buffering for PNG reported by SpartanJ - 1.25 (2010-07-17) - refix trans_data warning (Won Chun) - 1.24 (2010-07-12) - perf improvements reading from files on platforms with lock-heavy fgetc() - minor perf improvements for jpeg - deprecated type-specific functions so we'll get feedback if they're needed - attempt to fix trans_data warning (Won Chun) - 1.23 fixed bug in iPhone support - 1.22 (2010-07-10) - removed image *writing* support - stbi_info support from Jetro Lauha - GIF support from Jean-Marc Lienher + removed support for installable file formats (stbi_loader) -- + would have been broken for IO callbacks anyway error cases in bmp and tga + give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in + decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from + Aurelien Pocheville 1.28 (2010-08-01) fix bug in GIF palette transparency + (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings 1.26 + (2010-07-24) fix bug in file buffering for PNG reported by SpartanJ 1.25 + (2010-07-17) refix trans_data warning (Won Chun) 1.24 (2010-07-12) perf + improvements reading from files on platforms with lock-heavy fgetc() minor + perf improvements for jpeg deprecated type-specific functions so we'll get + feedback if they're needed attempt to fix trans_data warning (Won Chun) 1.23 + fixed bug in iPhone support 1.22 (2010-07-10) removed image *writing* + support stbi_info support from Jetro Lauha GIF support from Jean-Marc Lienher iPhone PNG-extensions from James Brown - warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) - 1.21 fix use of 'stbi_uc' in header (reported by jon blow) - 1.20 added support for Softimage PIC, by Tom Seddon - 1.19 bug in interlaced PNG corruption check (found by ryg) - 1.18 (2008-08-02) - fix a threading bug (local mutable static) - 1.17 support interlaced PNG - 1.16 major bugfix - stbi__convert_format converted one too many pixels - 1.15 initialize some fields for thread safety - 1.14 fix threadsafe conversion bug - header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. + Janez (U+017D)emva) 1.21 fix use of 'stbi_uc' in header (reported by jon + blow) 1.20 added support for Softimage PIC, by Tom Seddon 1.19 bug in + interlaced PNG corruption check (found by ryg) 1.18 (2008-08-02) fix a + threading bug (local mutable static) 1.17 support interlaced PNG 1.16 + major bugfix - stbi__convert_format converted one too many pixels 1.15 + initialize some fields for thread safety 1.14 fix threadsafe conversion + bug header-file-only version (#define STBI_HEADER_FILE_ONLY before including) 1.13 threadsafe 1.12 const qualifiers in the API 1.11 Support installable IDCT, colorspace conversion routines @@ -7685,15 +8479,14 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz 1.07 attempt to fix C++ warning/errors again 1.06 attempt to fix C++ warning/errors again - 1.05 fix TGA loading to return correct *comp and use good luminance calc - 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free - 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR - 1.02 support for (subset of) HDR files, float interface for preferred access to them - 1.01 fix bug: possible bug in handling right-side up bmps... not sure - fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all - 1.00 interface to zlib that skips zlib header - 0.99 correct handling of alpha in palette - 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 1.05 fix TGA loading to return correct *comp and use good luminance + calc 1.04 default float alpha is 1, not 255; use 'void *' for + stbi_image_free 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR 1.02 support + for (subset of) HDR files, float interface for preferred access to them 1.01 + fix bug: possible bug in handling right-side up bmps... not sure fix bug: the + stbi__bmp_load() and stbi__tga_load() functions didn't work at all 1.00 + interface to zlib that skips zlib header 0.99 correct handling of alpha in + palette 0.98 TGA loader by lonesock; dynamically add loaders (untested) 0.97 jpeg errors on too large a file; also catch another malloc failure 0.96 fix detection of invalid v value - particleman@mollyrocket forum 0.95 during header scan, seek to markers in case of padding @@ -7706,8 +8499,8 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 0.60 fix compiling as c++ 0.59 fix warnings: merge Dave Moore's -Wall fixes 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian - 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available - 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but + less than 16 available 0.56 fix bug: zlib uncompressed mode len vs. nlen 0.55 fix bug: restart_interval not initialized to 0 0.54 allow NULL for 'int *comp' 0.53 fix bug in png 3->4; speedup png decoding @@ -7718,7 +8511,6 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user first released version */ - /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. diff --git a/sample_programs/ml_sample_programs/vision_models/dave_keras/train.py b/sample_programs/ml_sample_programs/vision_models/dave_keras/train.py index 81110853..f82b551a 100644 --- a/sample_programs/ml_sample_programs/vision_models/dave_keras/train.py +++ b/sample_programs/ml_sample_programs/vision_models/dave_keras/train.py @@ -6,9 +6,8 @@ @author: berk """ - -from tensorflow.keras.models import Sequential -from tensorflow.keras.layers import Conv2D,Dense,Flatten,Dropout +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Conv2D, Dense, Flatten, Dropout import tensorflow as tf import scipy.misc import os @@ -19,112 +18,108 @@ input_shape = (66, 200, 3) - -#------------------------------------------------------ -#This section for reduce some errors related to GPU allocation on my system. -#it may not neccesary for yours. If it is, removing this part may increase the performance. -#from tensorflow import Session,ConfigProto -#from keras.backend.tensorflow_backend import set_session -#config = ConfigProto() -#config.gpu_options.per_process_gpu_memory_fraction = 0.3 -#set_session(Session(config=config)) -#-------------------------------------------------------- - +# ------------------------------------------------------ +# This section for reduce some errors related to GPU allocation on my system. +# it may not neccesary for yours. If it is, removing this part may increase the performance. +# from tensorflow import Session,ConfigProto +# from keras.backend.tensorflow_backend import set_session +# config = ConfigProto() +# config.gpu_options.per_process_gpu_memory_fraction = 0.3 +# set_session(Session(config=config)) +# -------------------------------------------------------- def atan_layer(x): return tf.multiply(tf.atan(x), 2) + def atan_layer_shape(input_shape): return input_shape -#our model is Nvidia Dave-2 where you can find here: https://arxiv.org/pdf/1604.07316.pdf +# our model is Nvidia Dave-2 where you can find here: https://arxiv.org/pdf/1604.07316.pdf def defineModel(): model = Sequential() # 5x5 Convolutional layers with stride of 2x2 - model.add(Conv2D(24, (5, 5), strides=(2, 2),activation='elu',input_shape=input_shape)) - model.add(Conv2D(36, (5, 5), strides=(2, 2),activation='elu')) - model.add(Conv2D(48, (5, 5), strides=(2, 2),activation='elu')) - + model.add( + Conv2D(24, (5, 5), strides=(2, 2), activation="elu", input_shape=input_shape) + ) + model.add(Conv2D(36, (5, 5), strides=(2, 2), activation="elu")) + model.add(Conv2D(48, (5, 5), strides=(2, 2), activation="elu")) + # 3x3 Convolutional layers with stride of 1x1 - model.add(Conv2D(64, (3, 3),activation='elu')) - model.add(Conv2D(64, (3, 3),activation='elu')) - + model.add(Conv2D(64, (3, 3), activation="elu")) + model.add(Conv2D(64, (3, 3), activation="elu")) + # Flatten before passing to the fully connected layers model.add(Flatten()) # Three fully connected layers - model.add(Dense(100,activation='elu')) - model.add(Dropout(.25)) - model.add(Dense(50,activation='elu')) - model.add(Dropout(.25)) - model.add(Dense(10,activation='elu')) - model.add(Dropout(.25)) - - # Output layer with linear activation - model.add(Dense(1,activation="linear")) - - return model + model.add(Dense(100, activation="elu")) + model.add(Dropout(0.25)) + model.add(Dense(50, activation="elu")) + model.add(Dropout(0.25)) + model.add(Dense(10, activation="elu")) + model.add(Dropout(0.25)) + # Output layer with linear activation + model.add(Dense(1, activation="linear")) - + return model -angle=[] -f= open("data.txt") #read steering angles from disk and preprocess +angle = [] + +f = open("data.txt") # read steering angles from disk and preprocess data = f.read() data = data.split() -for i in data: #if the node end with ".jpg" ignore it. what we need is only angles - if i[-1]=='g': +for i in data: # if the node end with ".jpg" ignore it. what we need is only angles + if i[-1] == "g": pass else: angle.append(float(i)) -paths= list(paths.list_images(os.getcwd()+"/data")) #collect the image ids (name) to send datagen.py -ids=[] +paths = list( + paths.list_images(os.getcwd() + "/data") +) # collect the image ids (name) to send datagen.py +ids = [] for i in paths: - name = i.split(os.path.sep)[-1] + name = i.split(os.path.sep)[-1] name = name[:-4] - ids.append(int(name)) -ids.sort() - + ids.append(int(name)) +ids.sort() -#train and test set ratio. We create two dictionary for data batch generator. -#partition consist of two list that holds the train and validation image ids. labels hold the angles.. -partition={'train':ids[:int(len(ids)*.8)],'validation':ids[-int(len(ids)*.2):]} -labels={} +# train and test set ratio. We create two dictionary for data batch generator. +# partition consist of two list that holds the train and validation image ids. labels hold the angles.. +partition = { + "train": ids[: int(len(ids) * 0.8)], + "validation": ids[-int(len(ids) * 0.2) :], +} +labels = {} for i in partition["train"]: - labels[i]=float(angle[i])* scipy.pi / 180 + labels[i] = float(angle[i]) * scipy.pi / 180 for i in partition["validation"]: - labels[i]=float(angle[i])* scipy.pi / 180 - - + labels[i] = float(angle[i]) * scipy.pi / 180 - # Parameters for datagen.py -params = {'dim': (66,200,3), - 'batch_size': 2, - 'shuffle': True} +params = {"dim": (66, 200, 3), "batch_size": 2, "shuffle": True} # Generators training_generator = DataGenerator(partition["train"], labels, **params) validation_generator = DataGenerator(partition["validation"], labels, **params) -#defining our model and compile with adam optimizer and mean squere error. -model=defineModel() -model.compile(optimizer='adam', loss="mse") - -#train it for 10 epochs -model.fit_generator(generator=training_generator, - epochs=10, - validation_data=validation_generator) - -#save trained model. -model.save("model.h5") +# defining our model and compile with adam optimizer and mean squere error. +model = defineModel() +model.compile(optimizer="adam", loss="mse") +# train it for 10 epochs +model.fit_generator( + generator=training_generator, epochs=10, validation_data=validation_generator +) +# save trained model. +model.save("model.h5") diff --git a/sample_programs/ml_sample_programs/vision_models/googlenet-12/compile.sh b/sample_programs/ml_sample_programs/vision_models/googlenet-12/compile.sh index 2278f0fb..598822d2 100755 --- a/sample_programs/ml_sample_programs/vision_models/googlenet-12/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/googlenet-12/compile.sh @@ -9,7 +9,7 @@ fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp googlenet-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp googlenet-12.onnx mlir-translate -mlir-to-llvmir googlenet-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/inception-v1-12/compile.sh b/sample_programs/ml_sample_programs/vision_models/inception-v1-12/compile.sh index 493c7225..4274a3c1 100755 --- a/sample_programs/ml_sample_programs/vision_models/inception-v1-12/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/inception-v1-12/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp inception-v1-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp inception-v1-12.onnx mlir-translate -mlir-to-llvmir inception-v1-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/compile.sh b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/compile.sh index d5c73d2c..05174cd5 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/compile.sh @@ -4,7 +4,7 @@ python3 lenet-fmnist.py printf "\n[Compile Script]: Convert TF model to LLVM IR\n" python3 -m tf2onnx.convert --saved-model lenet-fmnist.tf --output model.onnx python3 ../../../../tools/ExtendONNXModel.py --model_path ./model.onnx --output_model_path ./extendedmodel.onnx > expected_op_seq.txt -onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp +onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp mlir-translate -mlir-to-llvmir extendedmodel.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/image.c b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/image.c index aac7662f..9f1c5292 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/image.c +++ b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/image.c @@ -157,7 +157,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/lenet-fmnist.py b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/lenet-fmnist.py index 78fd5ab1..2b6f7605 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/lenet-fmnist.py +++ b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/lenet-fmnist.py @@ -2,31 +2,45 @@ def process_images(images): - images = images.reshape((-1, 28, 28, 1)) - images = images / 255.0 - return images + images = images.reshape((-1, 28, 28, 1)) + images = images / 255.0 + return images -(train_images, train_labels), (test_images, test_labels) = datasets.fashion_mnist.load_data() + +(train_images, train_labels), (test_images, test_labels) = ( + datasets.fashion_mnist.load_data() +) train_images = process_images(train_images) test_images = process_images(test_images) model = models.Sequential() -model.add(layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1))) +model.add( + layers.Conv2D( + filters=6, kernel_size=(5, 5), activation="relu", input_shape=(28, 28, 1) + ) +) model.add(layers.AveragePooling2D()) -model.add(layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu')) +model.add(layers.Conv2D(filters=16, kernel_size=(5, 5), activation="relu")) model.add(layers.AveragePooling2D()) model.add(layers.Flatten()) -model.add(layers.Dense(units=120, activation='relu')) -model.add(layers.Dense(units=84, activation='relu')) -model.add(layers.Dense(units=10, activation = 'softmax')) +model.add(layers.Dense(units=120, activation="relu")) +model.add(layers.Dense(units=84, activation="relu")) +model.add(layers.Dense(units=10, activation="softmax")) -model.compile(optimizer='adam', - loss=losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=['accuracy']) +model.compile( + optimizer="adam", + loss=losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=["accuracy"], +) # Save the untrained weights for future training with modified dataset -model.fit(train_images, train_labels, batch_size=100, epochs=10, - validation_data=(test_images, test_labels)) - -model.save('./lenet-fmnist.tf') +model.fit( + train_images, + train_labels, + batch_size=100, + epochs=10, + validation_data=(test_images, test_labels), +) + +model.save("./lenet-fmnist.tf") diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/stb_image.h b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/stb_image.h index 1b7e7b02..5b891039 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/stb_image.h +++ b/sample_programs/ml_sample_programs/vision_models/lenet-fmnist/stb_image.h @@ -3,7 +3,8 @@ Do this: #define STB_IMAGE_IMPLEMENTATION - before you include this file in *one* C or C++ file to create the implementation. + before you include this file in *one* C or C++ file to create the +implementation. // i.e. it should look like this: #include ... @@ -13,15 +14,16 @@ #include "stb_image.h" You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. - And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using +malloc,realloc,free QUICK NOTES: Primarily of interest to game developers and other people who can avoid problematic images and only need the trivial interface - JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) - PNG 1/2/4/8/16-bit-per-channel + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as +stock IJG lib) PNG 1/2/4/8/16-bit-per-channel TGA (not sure what subset, if a subset) BMP non-1bpp, non-RLE @@ -50,25 +52,22 @@ RECENT REVISION HISTORY: 2.26 (2020-07-13) many minor fixes 2.25 (2020-02-02) fix warnings - 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically - 2.23 (2019-08-11) fix clang static analysis warning - 2.22 (2019-03-04) gif fixes, fix warnings - 2.21 (2019-02-25) fix typo in comment - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and +flip_vertically 2.23 (2019-08-11) fix clang static analysis warning 2.22 +(2019-03-04) gif fixes, fix warnings 2.21 (2019-02-25) fix typo in comment 2.20 +(2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix warnings 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings - 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes - 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 - RGB-format JPEG; remove white matting in PSD; - allocate large structures on the stack; - correct channel count for PNG & BMP - 2.10 (2016-01-22) avoid warning introduced in 2.09 - 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; +bugfixes 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE +detection on GCC 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for +Imagenet JPGs 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; +fixes 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 +(2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 RGB-format JPEG; remove +white matting in PSD; allocate large structures on the stack; correct channel +count for PNG & BMP 2.10 (2016-01-22) avoid warning introduced in 2.09 2.09 +(2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED See end of file for full revision history. @@ -86,38 +85,37 @@ RECENT REVISION HISTORY: github:urraka (animated gif) Junggon Kim (PNM comments) Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) socks-the-fox (16-bit PNG) - Jeremy Sawicki (handle all ImageNet JPGs) - Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Jeremy Sawicki (handle all ImageNet +JPGs) Optimizations & bugfixes Mikhail Morozov (1-bit BMP) Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) Arseny Kapoulkine John-Mark Allen Carmelo J Fdez-Aguera Bug & warning fixes - Marc LeBlanc David Woo Guillaume George Martins Mozeiko - Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski - Phil Jordan Dave Moore Roy Eltham - Hayaki Saito Nathan Reed Won Chun - Luke Graham Johan Duparc Nick Verigakis the Horde3D community - Thomas Ruf Ronny Chevalier github:rlyeh - Janez Zemva John Bartholomew Michal Cichon github:romigrou - Jonathan Blow Ken Hamada Tero Hanninen github:svdijk - Laurent Gomila Cort Stratton github:snagar - Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex - Cass Everitt Ryamond Barbiero github:grim210 - Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw - Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus - Josh Tobin Matthew Gregan github:poppolopoppo - Julian Raschke Gregory Mullen Christian Floisand github:darealshinji - Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 - Brad Weinberger Matvey Cherevko [reserved] - Luca Sas Alexander Veselov Zack Middleton [reserved] + Marc LeBlanc David Woo Guillaume George Martins +Mozeiko Christpher Lloyd Jerry Jansson Joseph Thomson Blazej +Dariusz Roszkowski Phil Jordan Dave Moore Roy +Eltham Hayaki Saito Nathan Reed Won Chun Luke Graham Johan +Duparc Nick Verigakis the Horde3D community Thomas Ruf Ronny +Chevalier github:rlyeh Janez Zemva John +Bartholomew Michal Cichon github:romigrou Jonathan Blow Ken +Hamada Tero Hanninen github:svdijk Laurent Gomila Cort +Stratton github:snagar Aruelien Pocheville Sergio Gonzalez Thibault +Reuille github:Zelex Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Matthew Gregan +github:poppolopoppo Julian Raschke Gregory Mullen Christian +Floisand github:darealshinji Baldur Karlsson Kevin Schmidt JR +Smith github:Michaelangel007 Brad Weinberger Matvey Cherevko +[reserved] Luca Sas Alexander Veselov Zack Middleton [reserved] Ryan C. Gordon [reserved] [reserved] DO NOT ADD YOUR NAME HERE - To add your name to the credits, pick a random blank space in the middle and fill it. - 80% of merge conflicts on stb PRs are due to people adding their name at the end - of the credits. + To add your name to the credits, pick a random blank space in the middle and +fill it. 80% of merge conflicts on stb PRs are due to people adding their name +at the end of the credits. */ #ifndef STBI_INCLUDE_STB_IMAGE_H @@ -136,14 +134,15 @@ RECENT REVISION HISTORY: // // ... process data if not NULL ... // // ... x = width, y = height, n = # 8-bit components per pixel ... // // ... replace '0' with '1'..'4' to force that many components per pixel -// // ... but 'n' will always be the number that it would have been if you said 0 -// stbi_image_free(data) +// // ... but 'n' will always be the number that it would have been if you +// said 0 stbi_image_free(data) // // Standard parameters: // int *x -- outputs image width in pixels // int *y -- outputs image height in pixels // int *channels_in_file -- outputs # of image components in image file -// int desired_channels -- if non-zero, # of image components requested in result +// int desired_channels -- if non-zero, # of image components requested in +// result // // The return value from an image loader is an 'unsigned char *' which points // to the pixel data, or NULL on an allocation failure or if the image is @@ -171,8 +170,8 @@ RECENT REVISION HISTORY: // and *x, *y, *channels_in_file will be unchanged. The function // stbi_failure_reason() can be queried for an extremely brief, end-user // unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS -// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly -// more user-friendly ones. +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get +// slightly more user-friendly ones. // // Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. // @@ -196,11 +195,12 @@ RECENT REVISION HISTORY: // 2. easy to maintain // 3. good performance // -// Sometimes I let "good performance" creep up in priority over "easy to maintain", -// and for best performance I may provide less-easy-to-use APIs that give higher -// performance, in addition to the easy-to-use ones. Nevertheless, it's important -// to keep in mind that from the standpoint of you, a client of this library, -// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// Sometimes I let "good performance" creep up in priority over "easy to +// maintain", and for best performance I may provide less-easy-to-use APIs that +// give higher performance, in addition to the easy-to-use ones. Nevertheless, +// it's important to keep in mind that from the standpoint of you, a client of +// this library, all you care about is #1 and #3, and stb libraries DO NOT +// emphasize #3 above all. // // Some secondary priorities arise directly from the first two, some of which // provide more explicit reasons why performance can't be emphasized. @@ -219,7 +219,8 @@ RECENT REVISION HISTORY: // overhead. // // The three functions you must define are "read" (reads some bytes of data), -// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the +// end). // // =========================================================================== // @@ -247,10 +248,11 @@ RECENT REVISION HISTORY: // HDR image support (disable by defining STBI_NO_HDR) // // stb_image supports loading HDR images in general, and currently the Radiance -// .HDR file format specifically. You can still load any file through the existing -// interface; if you attempt to load an HDR file, it will be automatically remapped -// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; -// both of these constants can be reconfigured through this interface: +// .HDR file format specifically. You can still load any file through the +// existing interface; if you attempt to load an HDR file, it will be +// automatically remapped to LDR, assuming gamma 2.2 and an arbitrary scale +// factor defaulting to 1; both of these constants can be reconfigured through +// this interface: // // stbi_hdr_to_ldr_gamma(2.2f); // stbi_hdr_to_ldr_scale(1.0f); @@ -342,14 +344,13 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum -{ - STBI_default = 0, // only used for desired_channels +enum { + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -377,11 +378,13 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct -{ - int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof) (void *user); // returns nonzero if we are at end of file/data +typedef struct { + int (*read)(void *user, char *data, + int size); // fill 'data' with 'size' bytes. return number of + // bytes actually read + void (*skip)(void *user, int n); // skip the next 'n' bytes, or 'unget' the + // last -n bytes if negative + int (*eof)(void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -389,21 +392,33 @@ typedef struct // 8-bits-per-channel interface // -STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); -// for stbi_load_from_file, file pointer is left pointing immediately after image +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after +// image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input); #endif //////////////////////////////////// @@ -411,12 +426,20 @@ STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wch // 16-bits-per-channel interface // -STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); #endif //////////////////////////////////// @@ -424,83 +447,102 @@ STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_i // float-per-channel interface // #ifndef STBI_NO_LINEAR - STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); - #ifndef STBI_NO_STDIO - STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); - #endif +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +#endif #endif #ifndef STBI_NO_HDR - STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); - STBIDEF void stbi_hdr_to_ldr_scale(float scale); +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); +STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR - STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); - STBIDEF void stbi_ldr_to_hdr_scale(float scale); +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); +STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename); -STBIDEF int stbi_is_hdr_from_file(FILE *f); +STBIDEF int stbi_is_hdr(char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO - // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char *stbi_failure_reason (void); +STBIDEF const char *stbi_failure_reason(void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free (void *retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, + int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, + void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit (char const *filename); -STBIDEF int stbi_is_16_bit_from_file(FILE *f); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit(char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif - - // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); // indicate whether we should process iphone images back to canonical format, // or just pass them through "as-is" STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); -// flip the image vertically, so the first pixel in the output array is the bottom left +// flip the image vertically, so the first pixel in the output array is the +// bottom left STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); -// as above, but only applies to images loaded on the thread that calls the function -// this function is only available if your compiler supports thread-local variables; -// calling it will fail to link if your compiler doesn't -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); +// as above, but only applies to images loaded on the thread that calls the +// function this function is only available if your compiler supports +// thread-local variables; calling it will fail to link if your compiler doesn't +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); // ZLIB client - used by PNG, available for other purposes -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header); STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); - -STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, + int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); #ifdef __cplusplus } @@ -513,52 +555,53 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ - || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ - || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ - || defined(STBI_ONLY_ZLIB) - #ifndef STBI_ONLY_JPEG - #define STBI_NO_JPEG - #endif - #ifndef STBI_ONLY_PNG - #define STBI_NO_PNG - #endif - #ifndef STBI_ONLY_BMP - #define STBI_NO_BMP - #endif - #ifndef STBI_ONLY_PSD - #define STBI_NO_PSD - #endif - #ifndef STBI_ONLY_TGA - #define STBI_NO_TGA - #endif - #ifndef STBI_ONLY_GIF - #define STBI_NO_GIF - #endif - #ifndef STBI_ONLY_HDR - #define STBI_NO_HDR - #endif - #ifndef STBI_ONLY_PIC - #define STBI_NO_PIC - #endif - #ifndef STBI_ONLY_PNM - #define STBI_NO_PNM - #endif -#endif - -#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) -#define STBI_NO_ZLIB +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || \ + defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || \ + defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ + defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) +#ifndef STBI_ONLY_JPEG +#define STBI_NO_JPEG +#endif +#ifndef STBI_ONLY_PNG +#define STBI_NO_PNG +#endif +#ifndef STBI_ONLY_BMP +#define STBI_NO_BMP +#endif +#ifndef STBI_ONLY_PSD +#define STBI_NO_PSD +#endif +#ifndef STBI_ONLY_TGA +#define STBI_NO_TGA +#endif +#ifndef STBI_ONLY_GIF +#define STBI_NO_GIF +#endif +#ifndef STBI_ONLY_HDR +#define STBI_NO_HDR +#endif +#ifndef STBI_ONLY_PIC +#define STBI_NO_PIC +#endif +#ifndef STBI_ONLY_PNM +#define STBI_NO_PNM +#endif #endif +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && \ + !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif +#include #include #include // ptrdiff_t on osx #include #include -#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -576,55 +619,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #define STBI_EXTERN extern #endif - #ifndef _MSC_VER - #ifdef __cplusplus - #define stbi_inline inline - #else - #define stbi_inline - #endif +#ifdef __cplusplus +#define stbi_inline inline +#else +#define stbi_inline +#endif #else - #define stbi_inline __forceinline +#define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS - #if defined(__cplusplus) && __cplusplus >= 201103L - #define STBI_THREAD_LOCAL thread_local - #elif defined(__GNUC__) && __GNUC__ < 5 - #define STBI_THREAD_LOCAL __thread - #elif defined(_MSC_VER) - #define STBI_THREAD_LOCAL __declspec(thread) - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) - #define STBI_THREAD_LOCAL _Thread_local - #endif - - #ifndef STBI_THREAD_LOCAL - #if defined(__GNUC__) - #define STBI_THREAD_LOCAL __thread - #endif - #endif +#if defined(__cplusplus) && __cplusplus >= 201103L +#define STBI_THREAD_LOCAL thread_local +#elif defined(__GNUC__) && __GNUC__ < 5 +#define STBI_THREAD_LOCAL __thread +#elif defined(_MSC_VER) +#define STBI_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && \ + !defined(__STDC_NO_THREADS__) +#define STBI_THREAD_LOCAL _Thread_local +#endif + +#ifndef STBI_THREAD_LOCAL +#if defined(__GNUC__) +#define STBI_THREAD_LOCAL __thread +#endif +#endif #endif #ifdef _MSC_VER typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -632,27 +675,30 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL - #define stbi_lrot(x,y) _lrotl(x,y) +#define stbi_lrot(x, y) _lrotl(x, y) #else - #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (32 - (y)))) +#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (32 - (y)))) #endif -#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +#if defined(STBI_MALLOC) && defined(STBI_FREE) && \ + (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) // ok -#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && \ + !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) // ok #else -#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." +#error \ + "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p,newsz) realloc(p,newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p, newsz) realloc(p, newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) +#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) #endif // x86/x64 detection @@ -662,7 +708,8 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI__X86_TARGET #endif -#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && \ + !defined(STBI_NO_SIMD) // gcc doesn't support sse2 intrinsics unless you compile with -msse2, // which in turn means it gets to use SSE2 everywhere. This is unfortunate, // but previous attempts to provide the SSE2 functions with runtime @@ -673,8 +720,10 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI_NO_SIMD #endif -#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) -// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && \ + !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid +// STBI__X64_TARGET // // 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the // Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. @@ -684,44 +733,43 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; // See https://github.com/nothings/stb/issues/81 for more information. // // So default to no SSE2 on 32-bit MinGW. If you've read this far and added -// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +// -mstackrealign to your build settings, feel free to #define +// STBI_MINGW_ENABLE_SSE2. #define STBI_NO_SIMD #endif -#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#if !defined(STBI_NO_SIMD) && \ + (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) #define STBI_SSE2 #include #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) -{ - int info[4]; - __cpuid(info,1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) { + int info[4]; + __cpuid(info, 1); + return info[3]; } #else -static int stbi__cpuid3(void) -{ - int res; - __asm { +static int stbi__cpuid3(void) { + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) { + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -729,12 +777,11 @@ static int stbi__sse2_available(void) #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) { + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -766,188 +813,182 @@ static int stbi__sse2_available(void) // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct -{ - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct { + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void *io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; - static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) -{ - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) { + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) -{ - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, + void *user) { + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void *user, char *data, int size) -{ - return (int) fread(data,1,size,(FILE*) user); +static int stbi__stdio_read(void *user, char *data, int size) { + return (int)fread(data, 1, size, (FILE *)user); } -static void stbi__stdio_skip(void *user, int n) -{ - int ch; - fseek((FILE*) user, n, SEEK_CUR); - ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ - } +static void stbi__stdio_skip(void *user, int n) { + int ch; + fseek((FILE *)user, n, SEEK_CUR); + ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ + } } -static int stbi__stdio_eof(void *user) -{ - return feof((FILE*) user) || ferror((FILE *) user); +static int stbi__stdio_eof(void *user) { + return feof((FILE *)user) || ferror((FILE *)user); } -static stbi_io_callbacks stbi__stdio_callbacks = -{ - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static stbi_io_callbacks stbi__stdio_callbacks = { + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context *s, FILE *f) -{ - stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +static void stbi__start_file(stbi__context *s, FILE *f) { + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } -//static void stop_file(stbi__context *s) { } +// static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context *s) -{ - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 + // bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum -{ - STBI_ORDER_RGB, - STBI_ORDER_BGR -}; +enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; -typedef struct -{ - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct { + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context *s); -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context *s); -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__png_is16(stbi__context *s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context *s); -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context *s); -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s); -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__psd_is16(stbi__context *s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context *s); -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context *s); -static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context *s); -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s); -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); #endif static #ifdef STBI_THREAD_LOCAL -STBI_THREAD_LOCAL + STBI_THREAD_LOCAL #endif -const char *stbi__g_failure_reason; + const char *stbi__g_failure_reason; -STBIDEF const char *stbi_failure_reason(void) -{ - return stbi__g_failure_reason; +STBIDEF const char *stbi_failure_reason(void) { + return stbi__g_failure_reason; } #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char *str) -{ - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) { + stbi__g_failure_reason = str; + return 0; } #endif -static void *stbi__malloc(size_t size) -{ - return STBI_MALLOC(size); +static void *stbi__malloc(size_t size) { + return STBI_MALLOC(size); } // stb_image uses ints pervasively, including for offset calculations. @@ -962,70 +1003,72 @@ static void *stbi__malloc(size_t size) // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) -{ - if (b < 0) return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) { + if (b < 0) + return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) -{ - if (a < 0 || b < 0) return 0; - if (b == 0) return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX/b; +static int stbi__mul2sizes_valid(int a, int b) { + if (a < 0 || b < 0) + return 0; + if (b == 0) + return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; } -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__addsizes_valid(a*b*c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__addsizes_valid(a * b * c, add); } -// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't +// overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__mul2sizes_valid(a * b * c, d) && + stbi__addsizes_valid(a * b * c * d, add); } #endif -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void *stbi__malloc_mad2(int a, int b, int add) -{ - if (!stbi__mad2sizes_valid(a, b, add)) return NULL; - return stbi__malloc(a*b + add); +static void *stbi__malloc_mad2(int a, int b, int add) { + if (!stbi__mad2sizes_valid(a, b, add)) + return NULL; + return stbi__malloc(a * b + add); } #endif -static void *stbi__malloc_mad3(int a, int b, int c, int add) -{ - if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; - return stbi__malloc(a*b*c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) { + if (!stbi__mad3sizes_valid(a, b, c, add)) + return NULL; + return stbi__malloc(a * b * c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) -{ - if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; - return stbi__malloc(a*b*c*d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) { + if (!stbi__mad4sizes_valid(a, b, c, d, add)) + return NULL; + return stbi__malloc(a * b * c * d + add); } #endif @@ -1034,417 +1077,459 @@ static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS - #define stbi__err(x,y) 0 +#define stbi__err(x, y) 0 #elif defined(STBI_FAILURE_USERMSG) - #define stbi__err(x,y) stbi__err(y) +#define stbi__err(x, y) stbi__err(y) #else - #define stbi__err(x,y) stbi__err(x) +#define stbi__err(x, y) stbi__err(x) #endif -#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) -#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpuc(x, y) \ + ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -STBIDEF void stbi_image_free(void *retval_from_stbi_load) -{ - STBI_FREE(retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load) { + STBI_FREE(retval_from_stbi_load); } #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else -static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, + stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ - ? stbi__vertically_flip_on_load_local \ - : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load \ + (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; - - #ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNG - if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_GIF - if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PSD - if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); - #else - STBI_NOTUSED(bpc); - #endif - #ifndef STBI_NO_PIC - if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); - #endif - - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } - #endif - - #ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s,x,y,comp,req_comp, ri); - #endif - - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); -} - -static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi_uc *reduced; - - reduced = (stbi_uc *) stbi__malloc(img_len); - if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - - STBI_FREE(orig); - return reduced; -} - -static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi__uint16 *enlarged; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = + 8; // default is 8 so most paths don't have to be changed + ri->channel_order = + STBI_ORDER_RGB; // all current input & output are this, but this is here + // so we can add BGR order + ri->num_channels = 0; - enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); - if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); +#ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) + return stbi__jpeg_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNG + if (stbi__png_test(s)) + return stbi__png_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) + return stbi__bmp_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_GIF + if (stbi__gif_test(s)) + return stbi__gif_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PSD + if (stbi__psd_test(s)) + return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); +#else + STBI_NOTUSED(bpc); +#endif +#ifndef STBI_NO_PIC + if (stbi__pic_test(s)) + return stbi__pic_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) + return stbi__pnm_load(s, x, y, comp, req_comp, ri); +#endif - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } +#endif - STBI_FREE(orig); - return enlarged; -} +#ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s, x, y, comp, req_comp, ri); +#endif -static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) -{ - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc *bytes = (stbi_uc *)image; - - for (row = 0; row < (h>>1); row++) { - stbi_uc *row0 = bytes + row*bytes_per_row; - stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } + return stbi__errpuc("unknown image type", + "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *)stbi__malloc(img_len); + if (reduced == NULL) + return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = + (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient + // approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); + if (enlarged == NULL) + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + + orig[i]); // replicate to high and low byte, + // maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, + int bytes_per_pixel) { + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h >> 1); row++) { + stbi_uc *row0 = bytes + row * bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1) * bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = + (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) -{ - int slice; - int slice_size = w * h * bytes_per_pixel; +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, + int bytes_per_pixel) { + int slice; + int slice_size = w * h * bytes_per_pixel; - stbi_uc *bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *) result; + return (unsigned char *)result; } -static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to + // keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *) result; + return (stbi__uint16 *)result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) -{ - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, + int req_comp) { + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar( + unsigned int cp, unsigned long flags, const char *str, int cbmb, + wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte( + unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, + char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) -{ - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input) { + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, + (int)bufferlen, NULL, NULL); } #endif -static FILE *stbi__fopen(char const *filename, char const *mode) -{ - FILE *f; +static FILE *stbi__fopen(char const *filename, char const *mode) { + FILE *f; #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, + sizeof(wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) - return 0; + if (0 == + MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) + return 0; #if _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f=0; + if (0 != fopen_s(&f, filename, mode)) + f = 0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; -} - - -STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - unsigned char *result; - if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; -} - -STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__uint16 *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - stbi__uint16 *result; - if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f,x,y,comp,req_comp); - fclose(f); - return result; -} - - -#endif //!STBI_NO_STDIO - -STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); -} - -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + return f; +} + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) + return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) + return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +#endif //! STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_mem(&s,buffer,len); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_mem(&s, buffer, len); - result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); - } + result = + (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices(result, *x, *y, *z, *comp); + } - return result; + return result; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *data; - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data,x,y,comp,req_comp); - return hdr_data; - } - #endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); -} - -STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__loadf_main(&s,x,y,comp,req_comp); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp) { + unsigned char *data; +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data, x, y, comp, req_comp); + return hdr_data; + } +#endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", + "Image not of any known type, or corrupt"); } -STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__loadf_main(&s,x,y,comp,req_comp); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -#ifndef STBI_NO_STDIO -STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - float *result; - FILE *f = stbi__fopen(filename, "rb"); - if (!f) return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_file(&s,f); - return stbi__loadf_main(&s,x,y,comp,req_comp); +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, + int req_comp) { + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) + return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_file(&s, f); + return stbi__loadf_main(&s, x, y, comp, req_comp); } #endif // !STBI_NO_STDIO @@ -1454,221 +1539,222 @@ STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_ // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; +#endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result=0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result = 0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE *f) -{ - #ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s,f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; - #else - STBI_NOTUSED(f); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) { +#ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s, f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; +#else + STBI_NOTUSED(f); + return 0; +#endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; +#endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; +static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { + stbi__l2h_gamma = gamma; +} +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { + stbi__l2h_scale = scale; +} #endif -static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; - -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } +static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { + stbi__h2l_gamma_i = 1 / gamma; +} +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { + stbi__h2l_scale_i = 1 / scale; +} ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum -{ - STBI__SCAN_load=0, - STBI__SCAN_type, - STBI__SCAN_header -}; - -static void stbi__refill_buffer(stbi__context *s) -{ - int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); - s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start+1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } -} - -stbi_inline static stbi_uc stbi__get8(stbi__context *s) -{ - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; -} - -#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; + +static void stbi__refill_buffer(stbi__context *s) { + int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context *s) -{ - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) + return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) + return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context *s, int n) -{ - if (n == 0) return; // already there! - if (n < 0) { +static void stbi__skip(stbi__context *s, int n) { + if (n == 0) + return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); return; - } - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; + } + } + s->img_buffer += n; } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && \ + defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) -{ - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) { + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; - memcpy(buffer, s->img_buffer, blen); + memcpy(buffer, s->img_buffer, blen); - count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); - res = (count == (n-blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } + count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); + res = (count == (n - blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } - if (s->img_buffer+n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + if (s->img_buffer + n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context *s) -{ - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) { + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context *s) -{ - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) { + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context *s) -{ - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) { + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context *s) -{ - stbi__uint32 z = stbi__get16le(s); - return z + (stbi__get16le(s) << 16); +static stbi__uint32 stbi__get32le(stbi__context *s) { + stbi__uint32 z = stbi__get16le(s); + return z + (stbi__get16le(s) << 16); } #endif -#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) \ + ((stbi_uc)((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1682,169 +1768,301 @@ static stbi__uint32 stbi__get32le(stbi__context *s) // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) -{ - return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi_uc stbi__compute_y(int r, int g, int b) { + return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - unsigned char *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - unsigned char *src = data + j * x * img_n ; - unsigned char *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + unsigned char *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + unsigned char *src = data + j * x * img_n; + unsigned char *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 255; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 255; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 255; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = 255; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) -{ - return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { + return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - stbi__uint16 *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - stbi__uint16 *src = data + j * x * img_n ; - stbi__uint16 *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + stbi__uint16 *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + stbi__uint16 *src = data + j * x * img_n; + stbi__uint16 *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 0xffff; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 0xffff; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 0xffff; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = 0xffff; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return (stbi__uint16 *)stbi__errpuc("unsupported", + "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) -{ - int i,k,n; - float *output; - if (!data) return NULL; - output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i=0; i < x*y; ++i) { - output[i*comp + n] = data[i*comp + n]/255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) { + int i, k, n; + float *output; + if (!data) + return NULL; + output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpf("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + output[i * comp + k] = + (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * + stbi__l2h_scale); + } + } + if (n < comp) { + for (i = 0; i < x * y; ++i) { + output[i * comp + n] = data[i * comp + n] / 255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int) (x)) -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) -{ - int i,k,n; - stbi_uc *output; - if (!data) return NULL; - output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - if (k < comp) { - float z = data[i*comp+k] * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int)(x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) { + int i, k, n; + stbi_uc *output; + if (!data) + return NULL; + output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, + stbi__h2l_gamma_i) * + 255 + + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + if (k < comp) { + float z = data[i * comp + k] * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1872,750 +2090,791 @@ static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache - -typedef struct -{ - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct { + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct -{ - stbi__context *s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; - -// sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; - -// definition of jpeg image component - struct - { - int id; - int h,v; - int tq; - int hd,ha; - int dc_pred; - - int x,y,w2,h2; - stbi_uc *data; - void *raw_data, *raw_coeff; - stbi_uc *linebuf; - short *coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; - - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop - - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; - - int scan_n, order[4]; - int restart_interval, todo; - -// kernels - void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); - stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +typedef struct { + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, int count, + int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman *h, int *count) -{ - int i,j,k=0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i=0; i < 16; ++i) - for (j=0; j < count[i]; ++j) - h->size[k++] = (stbi_uc) (i+1); - h->size[k] = 0; - - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for(j=1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16) (code++); - if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16-j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; - - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i=0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS-s); - int m = 1 << (FAST_BITS-s); - for (j=0; j < m; ++j) { - h->fast[c+j] = (stbi_uc) i; - } +static int stbi__build_huffman(stbi__huffman *h, int *count) { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) + for (j = 0; j < count[i]; ++j) + h->size[k++] = (stbi_uc)(i + 1); + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16)(code++); + if (code - 1 >= (1u << j)) + return stbi__err("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (stbi_uc)i; } - } - return 1; + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) -{ - int i; - for (i=0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; - - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); - } +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) + k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); } - } -} - -static void stbi__grow_buffer_unsafe(stbi__jpeg *j) -{ - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char) c; - j->nomore = 1; - return; - } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) { + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) + c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; +static const stbi__uint32 stbi__bmask[17] = { + 0, 1, 3, 7, 15, 31, 63, 127, 255, + 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) -{ - unsigned int temp; - int c,k; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } - - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k=FAST_BITS+1 ; ; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } - - if (k > j->code_bits) +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) return -1; - - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1;; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & + stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) stbi__grow_buffer_unsafe(j); - - sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB - k = stbi_lrot(j->code_buffer, n); - if (n < 0 || n >= (int) (sizeof(stbi__bmask)/sizeof(*stbi__bmask))) return 0; - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & ~sgn); +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) { + unsigned int k; + int sgn; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + + sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB + k = stbi_lrot(j->code_buffer, n); + if (n < 0 || n >= (int)(sizeof(stbi__bmask) / sizeof(*stbi__bmask))) + return 0; + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & ~sgn); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) -{ - unsigned int k; - if (j->code_bits < n) stbi__grow_buffer_unsafe(j); - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; -} - -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) -{ - unsigned int k; - if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) { + unsigned int k; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) { + unsigned int k; + if (j->code_bits < 1) + stbi__grow_buffer_unsafe(j); + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64+15] = -{ - 0, 1, 8, 16, 9, 2, 3, 10, - 17, 24, 32, 25, 18, 11, 4, 5, - 12, 19, 26, 33, 40, 48, 41, 34, - 27, 20, 13, 6, 7, 14, 21, 28, - 35, 42, 49, 56, 57, 50, 43, 36, - 29, 22, 15, 23, 30, 37, 44, 51, - 58, 59, 52, 45, 38, 31, 39, 46, - 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, - 63, 63, 63, 63, 63, 63, 63 -}; +static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { + 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, + 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, + 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) -{ - int diff,dc,k; - int t; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - - // 0 all the ac values now so we can do it 32-bits at a time - memset(data,0,64*sizeof(data[0])); - - diff = t ? stbi__extend_receive(j, t) : 0; - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc * dequant[0]); - - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) * dequant[zig]); +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, stbi__huffman *hac, + stbi__int16 *fac, int b, + stbi__uint16 *dequant) { + int diff, dc, k; + int t; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data, 0, 64 * sizeof(data[0])); + + diff = t ? stbi__extend_receive(j, t) : 0; + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) + break; // end block + k += 16; } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (rs != 0xf0) break; // end block - k += 16; - } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); - } + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); } - } while (k < 64); - return 1; -} - -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) -{ - int diff,dc; - int t; - if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data,0,64*sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t == -1) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; - - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc << j->succ_low); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short) (1 << j->succ_low); - } - return 1; + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, int b) { + int diff, dc; + int t; + if (j->spec_end != 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t == -1) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; + + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc << j->succ_low); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; } // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) -{ - int k; - if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->succ_high == 0) { - int shift = j->succ_low; +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], + stbi__huffman *hac, + stbi__int16 *fac) { + int k; + if (j->spec_start == 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } - if (j->eob_run) { - --j->eob_run; - return 1; + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) << shift); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) << shift); + } } - + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { k = j->spec_start; do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) << shift); - } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) << shift); - } - } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients - - short bit = (short) (1 << j->succ_low); - - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short *p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } - } else { - k = j->spec_start; - do { - int r,s; - int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } + int r, s; + int rs = stbi__jpeg_huff_decode( + j, hac); // @OPTIMIZE see if we can use the fast path here, + // advance-by-r is so slow, eh + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) + return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } - // advance by r - while (k <= j->spec_end) { - short *p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short) s; - break; - } - --r; - } + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short)s; + break; } - } while (k <= j->spec_end); - } - } - return 1; + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) -{ - // trick to use a single test to catch both cases - if ((unsigned int) x > 255) { - if (x < 0) return 0; - if (x > 255) return 255; - } - return (stbi_uc) x; +stbi_inline static stbi_uc stbi__clamp(int x) { + // trick to use a single test to catch both cases + if ((unsigned int)x > 255) { + if (x < 0) + return 0; + if (x > 255) + return 255; + } + return (stbi_uc)x; } -#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) -#define stbi__fsh(x) ((x) * 4096) +#define stbi__f2f(x) ((int)(((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ - int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2+p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3*stbi__f2f(-1.847759065f); \ - t3 = p1 + p2*stbi__f2f( 0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2+p3); \ - t1 = stbi__fsh(p2-p3); \ - x0 = t0+t3; \ - x3 = t0-t3; \ - x1 = t1+t2; \ - x2 = t1-t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0+t2; \ - p4 = t1+t3; \ - p1 = t0+t3; \ - p2 = t1+t2; \ - p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ - t0 = t0*stbi__f2f( 0.298631336f); \ - t1 = t1*stbi__f2f( 2.053119869f); \ - t2 = t2*stbi__f2f( 3.072711026f); \ - t3 = t3*stbi__f2f( 1.501321110f); \ - p1 = p5 + p1*stbi__f2f(-0.899976223f); \ - p2 = p5 + p2*stbi__f2f(-2.562915447f); \ - p3 = p3*stbi__f2f(-1.961570560f); \ - p4 = p4*stbi__f2f(-0.390180644f); \ - t3 += p1+p4; \ - t2 += p2+p3; \ - t1 += p2+p4; \ - t0 += p1+p3; - -static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) -{ - int i,val[64],*v=val; - stbi_uc *o; - short *d = data; - - // columns - for (i=0; i < 8; ++i,++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 - && d[40]==0 && d[48]==0 && d[56]==0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0]*4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; x1 += 512; x2 += 512; x3 += 512; - v[ 0] = (x0+t3) >> 10; - v[56] = (x0-t3) >> 10; - v[ 8] = (x1+t2) >> 10; - v[48] = (x1-t2) >> 10; - v[16] = (x2+t1) >> 10; - v[40] = (x2-t1) >> 10; - v[24] = (x3+t0) >> 10; - v[32] = (x3-t0) >> 10; - } - } - - for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128<<17); - x1 += 65536 + (128<<17); - x2 += 65536 + (128<<17); - x3 += 65536 + (128<<17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0+t3) >> 17); - o[7] = stbi__clamp((x0-t3) >> 17); - o[1] = stbi__clamp((x1+t2) >> 17); - o[6] = stbi__clamp((x1-t2) >> 17); - o[2] = stbi__clamp((x2+t1) >> 17); - o[5] = stbi__clamp((x2-t1) >> 17); - o[3] = stbi__clamp((x3+t0) >> 17); - o[4] = stbi__clamp((x3-t0) >> 17); - } +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ + t0 += p1 + p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) { + int i, val[64], *v = val; + stbi_uc *o; + short *d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && + d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; + x1 += 512; + x2 += 512; + x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0 + t3) >> 17); + o[7] = stbi__clamp((x0 - t3) >> 17); + o[1] = stbi__clamp((x1 + t2) >> 17); + o[6] = stbi__clamp((x1 - t2) >> 17); + o[2] = stbi__clamp((x2 + t1) >> 17); + o[5] = stbi__clamp((x2 - t1) >> 17); + o[3] = stbi__clamp((x3 + t0) >> 17); + o[4] = stbi__clamp((x3 - t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; - - // dot product constant: even elems=x, odd elems=y - #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) - - // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) - // out(1) = c1[even]*x + c1[odd]*y - #define dct_rot(out0,out1, x,y,c0,c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) - - // out = in << 12 (in 16-bit, out 32-bit) - #define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - - // wide add - #define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - - // wide sub - #define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) - - // butterfly a/b, add bias, then shift by "s" and pack - #define dct_bfly32o(out0, out1, a,b,bias,s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } - - // 8-bit interleave step (for transposes) - #define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) - - // 16-bit interleave step (for transposes) - #define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) - - #define dct_pass(bias,shift) \ - { \ - /* even part */ \ - dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ - dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0,row7, x0,x7,bias,shift); \ - dct_bfly32o(row1,row6, x1,x6,bias,shift); \ - dct_bfly32o(row2,row5, x2,x5,bias,shift); \ - dct_bfly32o(row3,row4, x3,x4,bias,shift); \ - } +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + +// dot product constant: even elems=x, odd elems=y +#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + +// out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = \ + _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = \ + _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); - - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); - - // load - row0 = _mm_load_si128((const __m128i *) (data + 0*8)); - row1 = _mm_load_si128((const __m128i *) (data + 1*8)); - row2 = _mm_load_si128((const __m128i *) (data + 2*8)); - row3 = _mm_load_si128((const __m128i *) (data + 3*8)); - row4 = _mm_load_si128((const __m128i *) (data + 4*8)); - row5 = _mm_load_si128((const __m128i *) (data + 5*8)); - row6 = _mm_load_si128((const __m128i *) (data + 6*8)); - row7 = _mm_load_si128((const __m128i *) (data + 7*8)); - - // column pass - dct_pass(bias_0, 10); - - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); - - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); - - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } - - // row pass - dct_pass(bias_1, 17); - - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); - - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... - - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... - - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... +// wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - // store - _mm_storel_epi64((__m128i *) out, p0); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p2); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p1); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p3); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); - } +// wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + +// butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = \ + _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = \ + _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + +// 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + +// 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), + stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), + stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), + stbi__f2f(1.175875602f)); + __m128i rot1_1 = + dct_const(stbi__f2f(1.175875602f), + stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), + stbi__f2f(-1.961570560f)); + __m128i rot2_1 = + dct_const(stbi__f2f(-1.961570560f), + stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), + stbi__f2f(-0.390180644f)); + __m128i rot3_1 = + dct_const(stbi__f2f(-0.390180644f), + stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); + row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); + row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); + row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); + row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); + row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); + row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); + row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *)out, p0); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p2); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p1); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p3); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2634,198 +2893,236 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); - -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) - -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) - -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ - dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ - dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ - dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ - } - - // load - row0 = vld1q_s16(data + 0*8); - row1 = vld1q_s16(data + 1*8); - row2 = vld1q_s16(data + 2*8); - row3 = vld1q_s16(data + 3*8); - row4 = vld1q_s16(data + 4*8); - row5 = vld1q_s16(data + 5*8); - row6 = vld1q_s16(data + 6*8); - row7 = vld1q_s16(data + 7*8); - - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - - // column pass - dct_pass(vrshrn_n_s32, 10); - - // 16bit 8x8 transpose - { +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ + } + + // load + row0 = vld1q_s16(data + 0 * 8); + row1 = vld1q_s16(data + 1 * 8); + row2 = vld1q_s16(data + 2 * 8); + row3 = vld1q_s16(data + 3 * 8); + row4 = vld1q_s16(data + 4 * 8); + row5 = vld1q_s16(data + 5 * 8); + row6 = vld1q_s16(data + 6 * 8); + row7 = vld1q_s16(data + 7 * 8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } -#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } - - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); - - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); - - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = \ + vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ + } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } - - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); - - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } -#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } - - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! - - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); - - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); - - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); - - // store - vst1_u8(out, p0); out += out_stride; - vst1_u8(out, p1); out += out_stride; - vst1_u8(out, p2); out += out_stride; - vst1_u8(out, p3); out += out_stride; - vst1_u8(out, p4); out += out_stride; - vst1_u8(out, p5); out += out_stride; - vst1_u8(out, p6); out += out_stride; - vst1_u8(out, p7); + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ + } +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ + } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); + out += out_stride; + vst1_u8(out, p1); + out += out_stride; + vst1_u8(out, p2); + out += out_stride; + vst1_u8(out, p3); + out += out_stride; + vst1_u8(out, p4); + out += out_stride; + vst1_u8(out, p5); + out += out_stride; + vst1_u8(out, p6); + out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -2838,1132 +3135,1274 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg *j) -{ - stbi_uc x; - if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } - x = stbi__get8(j->s); - if (x != 0xff) return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) { + stbi_uc x; + if (j->marker != STBI__MARKER_none) { + x = j->marker; + j->marker = STBI__MARKER_none; + return x; + } + x = stbi__get8(j->s); + if (x != 0xff) + return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg *j) -{ - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels -} - -static int stbi__parse_entropy_coded_data(stbi__jpeg *z) -{ - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i,j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - STBI_SIMD_ALIGN(short, data[64]); - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x)*8; - int y2 = (j*z->img_comp[n].v + y)*8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; +static void stbi__jpeg_reset(stbi__jpeg *j) { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = + j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) { + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } else { - if (z->scan_n == 1) { - int i,j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x); - int y2 = (j*z->img_comp[n].v + y); - short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } + return 1; + } else { // interleaved + int i, j, k, x, y; + STBI_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, + z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + + z->img_comp[n].w2 * y2 + x2, + z->img_comp[n].w2, data); + } } - } - return 1; + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) -{ - int i; - for (i=0; i < 64; ++i) - data[i] *= dequant[i]; -} - -static void stbi__jpeg_finish(stbi__jpeg *z) -{ - if (z->progressive) { - // dequantize and idct the data - int i,j,n; - for (n=0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - } - } + return 1; + } + } else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], + z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static int stbi__process_marker(stbi__jpeg *z, int m) -{ - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker","Corrupt JPEG"); - - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; - - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s)-2; - while (L > 0) { - int q = stbi__get8(z->s); - int p = q >> 4, sixteen = (p != 0); - int t = q & 15,i; - if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); - if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - - for (i=0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); - L -= (sixteen ? 129 : 65); - } - return L==0; - - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s)-2; - while (L > 0) { - stbi_uc *v; - int sizes[16],i,n=0; - int q = stbi__get8(z->s); - int tc = q >> 4; - int th = q & 15; - if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); - for (i=0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; - } - L -= 17; - if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; - v = z->huff_dc[th].values; - } else { - if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; - v = z->huff_ac[th].values; + return 1; + } else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } } - for (i=0; i < n; ++i) - v[i] = stbi__get8(z->s); - if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); - L -= n; - } - return L==0; - } - - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len","Corrupt JPEG"); - else - return stbi__err("bad APP len","Corrupt JPEG"); + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - L -= 2; - - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J','F','I','F','\0'}; - int ok = 1; - int i; - for (i=0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; - int ok = 1; - int i; - for (i=0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + } } + } + } +} - stbi__skip(z->s, L); - return 1; - } +static int stbi__process_marker(stbi__jpeg *z, int m) { + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker", "Corrupt JPEG"); - return stbi__err("unknown marker","Corrupt JPEG"); -} + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) + return stbi__err("bad DRI len", "Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; -// after we see SOS -static int stbi__process_scan_header(stbi__jpeg *z) -{ - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); - if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); - for (i=0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s) - 2; + while (L > 0) { int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) return 0; // no match - z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); - z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); - z->order[i] = which; - } - - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) + return stbi__err("bad DQT type", "Corrupt JPEG"); + if (t > 3) + return stbi__err("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = + (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + stbi_uc *v; + int sizes[16], i, n = 0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) + return stbi__err("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc + th, sizes)) + return 0; + v = z->huff_dc[th].values; } else { - if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); - z->spec_end = 63; + if (!stbi__build_huffman(z->huff_ac + th, sizes)) + return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len", "Corrupt JPEG"); + else + return stbi__err("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; } - } + } + + stbi__skip(z->s, L); + return 1; + } - return 1; + return stbi__err("unknown marker", "Corrupt JPEG"); } -static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) -{ - int i; - for (i=0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) { + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) + return stbi__err("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) + return stbi__err("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) + return 0; // no match + z->img_comp[which].hd = q >> 4; + if (z->img_comp[which].hd > 3) + return stbi__err("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; + if (z->img_comp[which].ha > 3) + return stbi__err("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || + z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; } -static int stbi__process_frame_header(stbi__jpeg *z, int scan) -{ - stbi__context *s = z->s; - int Lf,p,i,q, h_max=1,v_max=1,c; - Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG - p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); - s->img_n = c; - for (i=0; i < c; ++i) { +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } - - if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); - - z->rgb = 0; - for (i=0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = { 'R', 'G', 'B' }; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); - z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); - } - - if (scan != STBI__SCAN_load) return 1; - - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); - - for (i=0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; - } - - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; - - for (i=0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); - } - } + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) { + stbi__context *s = z->s; + int Lf, p, i, q, h_max = 1, v_max = 1, c; + Lf = stbi__get16be(s); + if (Lf < 11) + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + p = stbi__get8(s); + if (p != 8) + return stbi__err("only 8-bit", + "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); + if (s->img_y == 0) + return stbi__err( + "no header height", + "JPEG format not supported: delayed height"); // Legal, but we don't + // handle it--but neither + // does IJG + s->img_x = stbi__get16be(s); + if (s->img_x == 0) + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) + return stbi__err("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) + return stbi__err("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = {'R', 'G', 'B'}; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); + if (!z->img_comp[i].h || z->img_comp[i].h > 4) + return stbi__err("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; + if (!z->img_comp[i].v || z->img_comp[i].v > 4) + return stbi__err("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); + if (z->img_comp[i].tq > 3) + return stbi__err("bad TQ", "Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) + return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) + return stbi__err("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) + h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) + v_max = z->img_comp[i].v; + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked + // earlier) so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = + stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i + 1, + stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = + (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3( + z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components( + z, i + 1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = + (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } - return 1; + return 1; } // use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) - -#define stbi__SOF_progressive(x) ((x) == 0xc2) - -static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) -{ - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); - if (scan == STBI__SCAN_type) return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z,m)) return 0; +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) + return stbi__err("no SOI", "Corrupt JPEG"); + if (scan == STBI__SCAN_type) + return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z, m)) + return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) + return stbi__err("no SOF", "Corrupt JPEG"); m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) return 0; - return 1; + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) + return 0; + return 1; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg *j) -{ - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) return 0; - if (!stbi__parse_entropy_coded_data(j)) return 0; - if (j->marker == STBI__MARKER_none ) { - // handle 0s at the end of image data from IP Kamera 9060 - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - if (x == 255) { - j->marker = stbi__get8(j->s); - break; - } - } - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); - } else { - if (!stbi__process_marker(j, m)) return 0; +static int stbi__decode_jpeg_image(stbi__jpeg *j) { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) + return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) + return 0; + if (!stbi__parse_entropy_coded_data(j)) + return 0; + if (j->marker == STBI__MARKER_none) { + // handle 0s at the end of image data from IP Kamera 9060 + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + if (x == 255) { + j->marker = stbi__get8(j->s); + break; + } + } + // if we reach eof without hitting a marker, stbi__get_marker() below + // will fail and we'll eventually return 0 } - m = stbi__get_marker(j); - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) + return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) + return stbi__err("bad DNL height", "Corrupt JPEG"); + } else { + if (!stbi__process_marker(j, m)) + return 0; + } + m = stbi__get_marker(j); + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, - int w, int hs); - -#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) - -static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; -} - -static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i=0; i < w; ++i) - out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); - return out; + int w, int hs); + +#define stbi__div4(x) ((stbi_uc)((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, + int w, int hs) { + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc *stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc *stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = stbi__div4(n + input[i - 1]); + out[i * 2 + 1] = stbi__div4(n + input[i + 1]); + } + out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc)((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = stbi__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; } -static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples horizontally for every one in input - int i; - stbi_uc *input = in_near; - - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } - - out[0] = input[0]; - out[1] = stbi__div4(input[0]*3 + input[1] + 2); - for (i=1; i < w-1; ++i) { - int n = 3*input[i]+2; - out[i*2+0] = stbi__div4(n+input[i-1]); - out[i*2+1] = stbi__div4(n+input[i+1]); - } - out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); - out[i*2+1] = input[w-1]; - - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); - - return out; -} +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = + _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *)(out + i * 2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = + vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif -#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } -static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i,t0,t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - t1 = 3*in_near[0] + in_far[0]; - out[0] = stbi__div4(t1+2); - for (i=1; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } +#endif -#if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i=0,t0,t1; - - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } - - t1 = 3*in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w-1) & ~7); i += 8) { -#if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); - - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); - - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *) (out + i*2), outv); -#elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); - - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i*2, o); -#endif - - // "previous" value for next iter - t1 = 3*in_near[i+7] + in_far[i+7]; - } - - t0 = t1; - t1 = 3*in_near[i] + in_far[i]; - out[i*2] = stbi__div16(3*t1 + t0 + 8); - - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); - - STBI_NOTUSED(hs); - - return out; -} -#endif - -static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // resample with nearest-neighbor - int i,j; - STBI_NOTUSED(in_far); - for (i=0; i < w; ++i) - for (j=0; j < hs; ++j) - out[i*hs+j] = in_near[i]; - return out; +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // resample with nearest-neighbor + int i, j; + STBI_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) -{ - int i; - for (i=0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int)((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, + int count, int step) { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) -{ - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, + stbi_uc const *pcb, stbi_uc const *pcr, + int count, int step) { + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); - __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); - __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); - __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); - __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); - __m128i xw = _mm_set1_epi16(255); // alpha channel - - for (; i+7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); - - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); - - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); - - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); - - // store - _mm_storeu_si128((__m128i *) (out + 0), o0); - _mm_storeu_si128((__m128i *) (out + 16), o1); - out += 32; - } - } + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *)(out + 0), o0); + _mm_storeu_si128((__m128i *)(out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); - int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); - int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); - int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); - - for (; i+7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); - - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); - - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); - - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8*4; - } - } -#endif - - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + cr * -stbi__float2fixed(0.71414f) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg *j) -{ - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) { + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg *j) -{ - stbi__free_jpeg_components(j, j->s->img_n, 0); +static void stbi__cleanup_jpeg(stbi__jpeg *j) { + stbi__free_jpeg_components(j, j->s->img_n, 0); } -typedef struct -{ - resample_row_func resample; - stbi_uc *line0,*line1; - int hs,vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +typedef struct { + resample_row_func resample; + stbi_uc *line0, *line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) -{ - unsigned int t = x*y + 128; - return (stbi_uc) ((t + (t >>8)) >> 8); -} - -static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) -{ - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe - - // validate req_comp - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } - - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; - - // resample and color-convert - { - int k; - unsigned int i,j; - stbi_uc *output; - stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; - - stbi__resample res_comp[4]; - - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs-1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; - else r->resample = stbi__resample_row_generic; +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { + unsigned int t = x * y + 128; + return (stbi_uc)((t + (t >> 8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, + int *comp, int req_comp) { + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && + (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // resample and color-convert + { + int k; + unsigned int i, j; + stbi_uc *output; + stbi_uc *coutput[4] = {NULL, NULL, NULL, NULL}; + + stbi__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); } - // can't error after this so, this is safe - output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - // now go ahead and resample - for (j=0; j < z->s->img_y; ++j) { - stbi_uc *out = output + n * z->s->img_x * j; - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, - y_bot ? r->line1 : r->line0, - y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) + r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) + r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) + r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) + r->resample = z->resample_row_hv_2_kernel; + else + r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = + r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; } - } - if (n >= 3) { - stbi_uc *y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i=0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } - } else { - if (is_rgb) { - if (n == 1) - for (i=0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i=0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc *y = coutput[0]; - if (n == 1) - for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; - else - for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; } - } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { + *out++ = y[i]; + *out++ = 255; + } + } } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } -} - -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - unsigned char* result; - stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x,y,comp,req_comp); - STBI_FREE(j); - return result; -} - -static int stbi__jpeg_test(stbi__context *s) -{ - int r; - stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; -} - -static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) -{ - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind( j->s ); - return 0; - } - if (x) *x = j->s->img_x; - if (y) *y = j->s->img_y; - if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; -} - -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) -{ - int result; - stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) + *comp = + z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + unsigned char *result; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) { + int r; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) { + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind(j->s); + return 0; + } + if (x) + *x = j->s->img_x; + if (y) + *y = j->s->img_y; + if (comp) + *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) { + int result; + stbi__jpeg *j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -3977,83 +4416,81 @@ static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct -{ - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[288]; - stbi__uint16 value[288]; +typedef struct { + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[288]; + stbi__uint16 value[288]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) -{ - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); +stbi_inline static int stbi__bitreverse16(int n) { + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); return n; } -stbi_inline static int stbi__bit_reverse(int v, int bits) -{ - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16-bits); -} - -static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) -{ - int i,k=0; - int code, next_code[16], sizes[17]; - - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i=0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i=1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i=1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16) code; - z->firstsymbol[i] = (stbi__uint16) k; - code = (code + sizes[i]); - if (sizes[i]) - if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); - z->maxcode[i] = code << (16-i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i=0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); - z->size [c] = (stbi_uc ) s; - z->value[c] = (stbi__uint16) i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s],s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } - } - ++next_code[s]; +stbi_inline static int stbi__bit_reverse(int v, int bits) { + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16 - bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, + int num) { + int i, k = 0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i = 0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i = 1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i = 1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16)code; + z->firstsymbol[i] = (stbi__uint16)k; + code = (code + sizes[i]); + if (sizes[i]) + if (code - 1 >= (1 << i)) + return stbi__err("bad codelengths", "Corrupt PNG"); + z->maxcode[i] = code << (16 - i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i = 0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); + z->size[c] = (stbi_uc)s; + z->value[c] = (stbi__uint16)i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s], s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } } - } - return 1; + ++next_code[s]; + } + } + return 1; } // zlib-from-memory implementation for PNG reading @@ -4062,277 +4499,313 @@ static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int // we require PNG read all the IDATs and combine them into a single // memory buffer -typedef struct -{ - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; +typedef struct { + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; - char *zout; - char *zout_start; - char *zout_end; - int z_expandable; + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; - stbi__zhuffman z_length, z_distance; + stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf *z) -{ - return (z->zbuffer >= z->zbuffer_end); -} - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) -{ - return stbi__zeof(z) ? 0 : *z->zbuffer++; -} - -static void stbi__fill_bits(stbi__zbuf *z) -{ - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); +stbi_inline static int stbi__zeof(stbi__zbuf *z) { + return (z->zbuffer >= z->zbuffer_end); } -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) -{ - unsigned int k; - if (z->num_bits < n) stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) { + return stbi__zeof(z) ? 0 : *z->zbuffer++; } -static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s,k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s=STBI__ZFAST_BITS+1; ; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) return -1; // invalid code! - // code size is s, so: - b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= sizeof (z->size)) return -1; // some data was corrupt somewhere! - if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes -{ - char *q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); - cur = (unsigned int) (z->zout - z->zout_start); - limit = old_limit = (unsigned) (z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; +static void stbi__fill_bits(stbi__zbuf *z) { + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) { + unsigned int k; + if (z->num_bits < n) + stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s, k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s = STBI__ZFAST_BITS + 1;; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) + return -1; // invalid code! + // code size is s, so: + b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= sizeof(z->size)) + return -1; // some data was corrupt somewhere! + if (z->size[b] != s) + return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. */ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, + int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) + return stbi__err("output buffer limit", "Corrupt PNG"); + cur = (unsigned int)(z->zout - z->zout_start); + limit = old_limit = (unsigned)(z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned)n) + return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if (limit > UINT_MAX / 2) + return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) + return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; } static const int stbi__zlength_base[31] = { - 3,4,5,6,7,8,9,10,11,13, - 15,17,19,23,27,31,35,43,51,59, - 67,83,99,115,131,163,195,227,258,0,0 }; - -static const int stbi__zlength_extra[31]= -{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; - -static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, -257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; - -static const int stbi__zdist_extra[32] = -{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; - -static int stbi__parse_huffman_block(stbi__zbuf *a) -{ - char *zout = a->zout; - for(;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) return 0; - zout = a->zout; - } - *zout++ = (char) z; + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + +static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, + 4, 4, 5, 5, 5, 5, 0, 0, 0}; + +static const int stbi__zdist_base[32] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; + +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) { + char *zout = a->zout; + for (;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) + return stbi__err("bad huffman code", + "Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) + return 0; + zout = a->zout; + } + *zout++ = (char)z; + } else { + stbi_uc *p; + int len, dist; + if (z == 256) { + a->zout = zout; + return 1; + } + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) + len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0) + return stbi__err("bad huffman code", "Corrupt PNG"); + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) + dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) + return stbi__err("bad dist", "Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) + return 0; + zout = a->zout; + } + p = (stbi_uc *)(zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { + do + *zout++ = v; + while (--len); + } } else { - stbi_uc *p; - int len,dist; - if (z == 256) { - a->zout = zout; - return 1; - } - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) return 0; - zout = a->zout; - } - p = (stbi_uc *) (zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { do *zout++ = v; while (--len); } - } else { - if (len) { do *zout++ = *p++; while (--len); } - } + if (len) { + do + *zout++ = *p++; + while (--len); + } } - } -} - -static int stbi__compute_huffman_codes(stbi__zbuf *a) -{ - static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286+32+137];//padding for maximum single op - stbi_uc codelength_sizes[19]; - int i,n; - - int hlit = stbi__zreceive(a,5) + 257; - int hdist = stbi__zreceive(a,5) + 1; - int hclen = stbi__zreceive(a,4) + 4; - int ntot = hlit + hdist; - - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i=0; i < hclen; ++i) { - int s = stbi__zreceive(a,3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; - - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc) c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a,2)+3; - if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n-1]; - } else if (c == 17) { - c = stbi__zreceive(a,3)+3; - } else if (c == 18) { - c = stbi__zreceive(a,7)+11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); - } - if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes+n, fill, c); - n += c; + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) { + static const stbi_uc length_dezigzag[19] = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc codelength_sizes[19]; + int i, n; + + int hlit = stbi__zreceive(a, 5) + 257; + int hdist = stbi__zreceive(a, 5) + 1; + int hclen = stbi__zreceive(a, 4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i = 0; i < hclen; ++i) { + int s = stbi__zreceive(a, 3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) + return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc)c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a, 2) + 3; + if (n == 0) + return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n - 1]; + } else if (c == 17) { + c = stbi__zreceive(a, 3) + 3; + } else if (c == 18) { + c = stbi__zreceive(a, 7) + 11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); } - } - if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; - return 1; -} - -static int stbi__parse_uncompressed_block(stbi__zbuf *a) -{ - stbi_uc header[4]; - int len,nlen,k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; -} - -static int stbi__parse_zlib_header(stbi__zbuf *a) -{ - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; -} - -static const stbi_uc stbi__zdefault_length[288] = -{ - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 -}; -static const stbi_uc stbi__zdefault_distance[32] = -{ - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 -}; + if (ntot - n < c) + return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes + n, fill, c); + n += c; + } + } + if (n != ntot) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) + return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) { + stbi_uc header[4]; + int len, nlen, k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = + (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) + return stbi__err("zlib corrupt", "Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) + return stbi__err("zlib corrupt", "Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) + return stbi__err("read past buffer", "Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) + return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) { + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if ((cmf * 256 + flg) % 31 != 0) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if (flg & 32) + return stbi__err("no preset dict", + "Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) + return stbi__err("bad compression", + "Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[288] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; +static const stbi_uc stbi__zdefault_distance[32] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; /* Init algorithm: { @@ -4346,117 +4819,131 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) -{ - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a,1); - type = stbi__zreceive(a,2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) return 0; - } else if (type == 3) { - return 0; +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) { + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) + return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a, 1); + type = stbi__zreceive(a, 2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) + return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, 288)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) + return 0; } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , 288)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; - } else { - if (!stbi__compute_huffman_codes(a)) return 0; - } - if (!stbi__parse_huffman_block(a)) return 0; + if (!stbi__compute_huffman_codes(a)) + return 0; } - } while (!final); - return 1; -} - -static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) -{ - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; - - return stbi__parse_zlib(a, parse_header); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) -{ - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int) (a.zout - a.zout_start); - else - return -1; -} - -STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(16384); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer+len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int) (a.zout - a.zout_start); - else - return -1; + if (!stbi__parse_huffman_block(a)) + return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, + int parse_header) { + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, + int *outlen) { + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + char const *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int)(a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, + int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(16384); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int)(a.zout - a.zout_start); + else + return -1; } #endif @@ -4471,1083 +4958,1312 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct -{ - stbi__uint32 length; - stbi__uint32 type; +typedef struct { + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) -{ - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) { + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context *s) -{ - static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; - int i; - for (i=0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) { + static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + int i; + for (i = 0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) + return stbi__err("bad png sig", "Not a PNG"); + return 1; } -typedef struct -{ - stbi__context *s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct { + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; - enum { - STBI__F_none=0, - STBI__F_sub=1, - STBI__F_up=2, - STBI__F_avg=3, - STBI__F_paeth=4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none = 0, + STBI__F_sub = 1, + STBI__F_up = 2, + STBI__F_avg = 3, + STBI__F_paeth = 4, + // synthetic filters used for first scanline to avoid needing a dummy row of + // 0s + STBI__F_avg_first, + STBI__F_paeth_first }; -static stbi_uc first_row_filter[5] = -{ - STBI__F_none, - STBI__F_sub, - STBI__F_none, - STBI__F_avg_first, - STBI__F_paeth_first -}; +static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, + STBI__F_avg_first, STBI__F_paeth_first}; -static int stbi__paeth(int a, int b, int c) -{ - int p = a + b - c; - int pa = abs(p-a); - int pb = abs(p-b); - int pc = abs(p-c); - if (pa <= pb && pa <= pc) return a; - if (pb <= pc) return b; - return c; +static int stbi__paeth(int a, int b, int c) { + int p = a + b - c; + int pa = abs(p - a); + int pb = abs(p - b); + int pc = abs(p - c); + if (pa <= pb && pa <= pc) + return a; + if (pb <= pc) + return b; + return c; } -static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; +static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, + 0, 0, 0, 0x01}; // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) -{ - int bytes = (depth == 16? 2 : 1); - stbi__context *s = a->s; - stbi__uint32 i,j,stride = x*out_n*bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later - - int output_bytes = out_n*bytes; - int filter_bytes = img_n*bytes; - int width = x; - - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); - a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) return stbi__err("outofmem", "Out of memory"); - - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; - - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); - - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *prior; - int filter = *raw++; - - if (filter > 4) - return stbi__err("invalid filter","Corrupt PNG"); - - if (depth < 8) { - if (img_width_bytes > x) return stbi__err("invalid width","Corrupt PNG"); - cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - - // if first row, use special filter that doesn't sample previous row - if (j == 0) filter = first_row_filter[filter]; - - // handle first byte explicitly - for (k=0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none : cur[k] = raw[k]; break; - case STBI__F_sub : cur[k] = raw[k]; break; - case STBI__F_up : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break; - case STBI__F_avg : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break; - case STBI__F_paeth : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break; - case STBI__F_avg_first : cur[k] = raw[k]; break; - case STBI__F_paeth_first: cur[k] = raw[k]; break; - } +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, + stbi__uint32 raw_len, int out_n, + stbi__uint32 x, stbi__uint32 y, int depth, + int color) { + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i, j, stride = x * out_n * bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n * bytes; + int filter_bytes = img_n * bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); + a->out = (stbi_uc *)stbi__malloc_mad3( + x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) + return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) + return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on + // non-interlaced PNGs, but issue #276 reported a PNG in the wild that had + // extra data at the end (all zeros), so just check for raw_len < img_len + // always. + if (raw_len < img_len) + return stbi__err("not enough pixels", "Corrupt PNG"); + + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter", "Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) + return stbi__err("invalid width", "Corrupt PNG"); + cur += + x * out_n - img_width_bytes; // store output to the rightmost img_len + // bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = + cur - + stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) + filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k = 0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none: + cur[k] = raw[k]; + break; + case STBI__F_sub: + cur[k] = raw[k]; + break; + case STBI__F_up: + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); + break; + case STBI__F_paeth: + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); + break; + case STBI__F_avg_first: + cur[k] = raw[k]; + break; + case STBI__F_paeth_first: + cur[k] = raw[k]; + break; } + } - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes+1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1)*filter_bytes; - #define STBI__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: memcpy(cur, raw, nk); break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break; - } - #undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n+1 == out_n); - #define STBI__CASE(f) \ - case f: \ - for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break; - } - #undef STBI__CASE - - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride*j; // start at the beginning of the row again - for (i=0; i < x; ++i,cur+=output_bytes) { - cur[filter_bytes+1] = 255; - } - } + // this is a little gross, so that we don't switch per-pixel or + // per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1) * filter_bytes; +#define STBI__CASE(f) \ + case f: \ + for (k = 0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. + case STBI__F_none: + memcpy(cur, raw, nk); + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - filter_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - filter_bytes], prior[k], + prior[k - filter_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); + } + break; } - } - - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *in = a->out + stride*j + x*out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit - // png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k=x*img_n; k >= 2; k-=2, ++in) { - *cur++ = scale * ((*in >> 4) ); - *cur++ = scale * ((*in ) & 0x0f); - } - if (k > 0) *cur++ = scale * ((*in >> 4) ); - } else if (depth == 2) { - for (k=x*img_n; k >= 4; k-=4, ++in) { - *cur++ = scale * ((*in >> 6) ); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in ) & 0x03); - } - if (k > 0) *cur++ = scale * ((*in >> 6) ); - if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k=x*img_n; k >= 8; k-=8, ++in) { - *cur++ = scale * ((*in >> 7) ); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in ) & 0x01); - } - if (k > 0) *cur++ = scale * ((*in >> 7) ); - if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride*j; - if (img_n == 1) { - for (q=x-1; q >= 0; --q) { - cur[q*2+1] = 255; - cur[q*2+0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q=x-1; q >= 0; --q) { - cur[q*4+3] = 255; - cur[q*4+2] = cur[q*3+2]; - cur[q*4+1] = cur[q*3+1]; - cur[q*4+0] = cur[q*3+0]; - } - } - } +#undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n + 1 == out_n); +#define STBI__CASE(f) \ + case f: \ + for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, \ + cur += output_bytes, prior += output_bytes) \ + for (k = 0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { + cur[k] = raw[k]; + } + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - output_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - output_bytes], prior[k], + prior[k - output_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); + } + break; + } +#undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride * j; // start at the beginning of the row again + for (i = 0; i < x; ++i, cur += output_bytes) { + cur[filter_bytes + 1] = 255; + } + } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *in = a->out + stride * j + x * out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common + // 8-bit path optimal at minimal cost for 1/2/4-bit png guarante byte + // alignment, if width is not multiple of 8/4/2 we'll decode dummy + // trailing data that will be skipped in the later loop + stbi_uc scale = (color == 0) + ? stbi__depth_scale_table[depth] + : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than + // desired. we can allocate enough data that this never writes out of + // memory, but it could also overwrite the next scanline. can it overwrite + // non-empty data on the next scanline? yes, consider 1-pixel-wide + // scanlines with 1-bit-per-pixel. so we need to explicitly clamp the + // final ones + + if (depth == 4) { + for (k = x * img_n; k >= 2; k -= 2, ++in) { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } else if (depth == 2) { + for (k = x * img_n; k >= 4; k -= 4, ++in) { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k = x * img_n; k >= 8; k -= 8, ++in) { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) + *cur++ = scale * ((*in >> 7)); + if (k > 1) + *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) + *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) + *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) + *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) + *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) + *cur++ = scale * ((*in >> 1) & 0x01); } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc *cur = a->out; - stbi__uint16 *cur16 = (stbi__uint16*)cur; - - for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) { - *cur16 = (cur[0] << 8) | cur[1]; + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride * j; + if (img_n == 1) { + for (q = x - 1; q >= 0; --q) { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q = x - 1; q >= 0; --q) { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } } - } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16 *)cur; + + for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, + stbi__uint32 image_data_len, int out_n, + int depth, int color, int interlaced) { + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, + a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + for (p = 0; p < 7; ++p) { + int xorig[] = {0, 4, 0, 2, 0, 1, 0}; + int yorig[] = {0, 0, 4, 0, 2, 0, 1}; + int xspc[] = {8, 8, 4, 4, 2, 2, 1}; + int yspc[] = {8, 8, 8, 4, 4, 2, 2}; + int i, j, x, y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, + y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j = 0; j < y; ++j) { + for (i = 0; i < x; ++i) { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, + a->out + (j * x + i) * out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - return 1; + return 1; } -static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) -{ - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc *final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - - // de-interlacing - final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - for (p=0; p < 7; ++p) { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; - int i,j,x,y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; - } - for (j=0; j < y; ++j) { - for (i=0; i < x; ++i) { - int out_y = j*yspc[p]+yorig[p]; - int out_x = i*xspc[p]+xorig[p]; - memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, - a->out + (j*x+i)*out_bytes, out_bytes); - } - } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - return 1; -} + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); -static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); - - if (out_n == 2) { - for (i=0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i=0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 *p = (stbi__uint16*) z->out; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], + int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16 *)z->out; - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) -{ - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; - - p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) return stbi__err("outofmem", "Out of memory"); - - // between here and free(out) below, exitting would leak - temp_out = p; - - if (pal_img_n == 3) { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p += 3; - } - } else { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p[3] = palette[n+3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, + int pal_img_n) { + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + + // between here and free(out) below, exitting would leak + temp_out = p; - STBI_NOTUSED(len); + if (pal_img_n == 3) { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p += 3; + } + } else { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p[3] = palette[n + 3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; - return 1; + STBI_NOTUSED(len); + + return 1; } static int stbi__unpremultiply_on_load = 0; static int stbi__de_iphone_flag = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) -{ - stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) -{ - stbi__de_iphone_flag = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { + stbi__de_iphone_flag = flag_true_if_should_convert; } -static void stbi__de_iphone(stbi__png *z) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - if (s->img_out_n == 3) { // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 3; +static void stbi__de_iphone(stbi__png *z) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i = 0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = (t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i=0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = ( t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } + } else { + // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a, b, c, d) \ + (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + \ + (unsigned)(d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) { + stbi_uc palette[1024], pal_img_n = 0; + stbi_uc has_trans = 0, tc[3] = {0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; + int first = 1, k, interlace = 0, color = 0, is_iphone = 0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) + return 0; + + if (scan == STBI__SCAN_type) + return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { + int comp, filter; + if (!first) + return stbi__err("multiple IHDR", "Corrupt PNG"); + first = 0; + if (c.length != 13) + return stbi__err("bad IHDR len", "Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + z->depth = stbi__get8(s); + if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && + z->depth != 16) + return stbi__err("1/2/4/8/16-bit only", + "PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); + if (color > 6) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3 && z->depth == 16) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3) + pal_img_n = 3; + else if (color & 1) + return stbi__err("bad ctype", "Corrupt PNG"); + comp = stbi__get8(s); + if (comp) + return stbi__err("bad comp method", "Corrupt PNG"); + filter = stbi__get8(s); + if (filter) + return stbi__err("bad filter method", "Corrupt PNG"); + interlace = stbi__get8(s); + if (interlace > 1) + return stbi__err("bad interlace method", "Corrupt PNG"); + if (!s->img_x || !s->img_y) + return stbi__err("0-pixel image", "Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) + return stbi__err("too large", "Image too large to decode"); + if (scan == STBI__SCAN_header) + return 1; } else { - // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) + return stbi__err("too large", "Corrupt PNG"); + // if SCAN_header, have to scan to see if we have a tRNS } - } -} - -#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + break; + } -static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) -{ - stbi_uc palette[1024], pal_img_n=0; - stbi_uc has_trans=0, tc[3]={0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; - int first=1,k,interlace=0, color=0, is_iphone=0; - stbi__context *s = z->s; - - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; - - if (!stbi__check_png_header(s)) return 0; - - if (scan == STBI__SCAN_type) return 1; - - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C','g','B','I'): - is_iphone = 1; - stbi__skip(s, c.length); - break; - case STBI__PNG_TYPE('I','H','D','R'): { - int comp,filter; - if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); - first = 0; - if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); - s->img_x = stbi__get32be(s); - s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); - comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); - filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); - interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); - if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); - if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); - if (scan == STBI__SCAN_header) return 1; - } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); - // if SCAN_header, have to scan to see if we have a tRNS - } - break; - } - - case STBI__PNG_TYPE('P','L','T','E'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); - pal_len = c.length / 3; - if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); - for (i=0; i < pal_len; ++i) { - palette[i*4+0] = stbi__get8(s); - palette[i*4+1] = stbi__get8(s); - palette[i*4+2] = stbi__get8(s); - palette[i*4+3] = 255; - } - break; - } - - case STBI__PNG_TYPE('t','R','N','S'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); - if (pal_img_n) { - if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } - if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); - if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); - pal_img_n = 4; - for (i=0; i < c.length; ++i) - palette[i*4+3] = stbi__get8(s); - } else { - if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); - if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); - has_trans = 1; - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } - } - break; - } - - case STBI__PNG_TYPE('I','D','A','T'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); - if (scan == STBI__SCAN_header) { s->img_n = pal_img_n; return 1; } - if ((int)(ioff + c.length) < (int)ioff) return 0; - if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc *p; - if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); - z->idata = p; - } - if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); - ioff += c.length; - break; - } - - case STBI__PNG_TYPE('I','E','N','D'): { - stbi__uint32 raw_len, bpl; - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) return 1; - if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); - // initial guess for decoded data size to avoid unnecessary reallocs - bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component - raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); - if (z->expanded == NULL) return 0; // zlib should set error - STBI_FREE(z->idata); z->idata = NULL; - if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n+1; - else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; - if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; - } - } - if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); - if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; - } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; - } - STBI_FREE(z->expanded); z->expanded = NULL; - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - return 1; - } - - default: - // if critical, fail - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if ((c.type & (1 << 29)) == 0) { - #ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); - #endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); - } - stbi__skip(s, c.length); - break; + case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256 * 3) + return stbi__err("invalid PLTE", "Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) + return stbi__err("invalid PLTE", "Corrupt PNG"); + for (i = 0; i < pal_len; ++i) { + palette[i * 4 + 0] = stbi__get8(s); + palette[i * 4 + 1] = stbi__get8(s); + palette[i * 4 + 2] = stbi__get8(s); + palette[i * 4 + 3] = 255; } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } -} + break; + } -static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) -{ - void *result=NULL; - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) return result; + case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) + return stbi__err("tRNS after IDAT", "Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { + s->img_n = 4; + return 1; + } + if (pal_len == 0) + return stbi__err("tRNS before PLTE", "Corrupt PNG"); + if (c.length > pal_len) + return stbi__err("bad tRNS len", "Corrupt PNG"); + pal_img_n = 4; + for (i = 0; i < c.length; ++i) + palette[i * 4 + 3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) + return stbi__err("tRNS with alpha", "Corrupt PNG"); + if (c.length != (stbi__uint32)s->img_n * 2) + return stbi__err("bad tRNS len", "Corrupt PNG"); + has_trans = 1; + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * + stbi__depth_scale_table[z->depth]; // non 8-bit images will + // be larger + } } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) *n = p->s->img_n; - } - STBI_FREE(p->out); p->out = NULL; - STBI_FREE(p->expanded); p->expanded = NULL; - STBI_FREE(p->idata); p->idata = NULL; - - return result; -} - -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi__png p; - p.s = s; - return stbi__do_png(&p, x,y,comp,req_comp, ri); -} - -static int stbi__png_test(stbi__context *s) -{ - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; -} + break; + } -static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) -{ - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind( p->s ); - return 0; - } - if (x) *x = p->s->img_x; - if (y) *y = p->s->img_y; - if (comp) *comp = p->s->img_n; - return 1; -} + case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) + return stbi__err("no PLTE", "Corrupt PNG"); + if (scan == STBI__SCAN_header) { + s->img_n = pal_img_n; + return 1; + } + if ((int)(ioff + c.length) < (int)ioff) + return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) + idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, + idata_limit); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata + ioff, c.length)) + return stbi__err("outofdata", "Corrupt PNG"); + ioff += c.length; + break; + } -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); -} + case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + stbi__uint32 raw_len, bpl; + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) + return 1; + if (z->idata == NULL) + return stbi__err("no IDAT", "Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag( + (char *)z->idata, ioff, raw_len, (int *)&raw_len, !is_iphone); + if (z->expanded == NULL) + return 0; // zlib should set error + STBI_FREE(z->idata); + z->idata = NULL; + if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || + has_trans) + s->img_out_n = s->img_n + 1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, + z->depth, color, interlace)) + return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) + return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) + return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) + s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); + z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } -static int stbi__png_is16(stbi__context *s) -{ - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; + default: + // if critical, fail + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { +#ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); +#endif + return stbi__err(invalid_chunk, + "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, + stbi__result_info *ri) { + void *result = NULL; + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", + "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) + return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) + *n = p->s->img_n; + } + STBI_FREE(p->out); + p->out = NULL; + STBI_FREE(p->expanded); + p->expanded = NULL; + STBI_FREE(p->idata); + p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi__png p; + p.s = s; + return stbi__do_png(&p, x, y, comp, req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) { + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) { + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind(p->s); + return 0; + } + if (x) + *x = p->s->img_x; + if (y) + *y = p->s->img_y; + if (comp) + *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) { + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) { + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context *s) -{ - int r; - int sz; - if (stbi__get8(s) != 'B') return 0; - if (stbi__get8(s) != 'M') return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; -} - -static int stbi__bmp_test(stbi__context *s) -{ - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test_raw(stbi__context *s) { + int r; + int sz; + if (stbi__get8(s) != 'B') + return 0; + if (stbi__get8(s) != 'M') + return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) { + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } - // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) -{ - int n=0; - if (z == 0) return -1; - if (z >= 0x10000) { n += 16; z >>= 16; } - if (z >= 0x00100) { n += 8; z >>= 8; } - if (z >= 0x00010) { n += 4; z >>= 4; } - if (z >= 0x00004) { n += 2; z >>= 2; } - if (z >= 0x00002) { n += 1;/* >>= 1;*/ } - return n; +static int stbi__high_bit(unsigned int z) { + int n = 0; + if (z == 0) + return -1; + if (z >= 0x10000) { + n += 16; + z >>= 16; + } + if (z >= 0x00100) { + n += 8; + z >>= 8; + } + if (z >= 0x00010) { + n += 4; + z >>= 4; + } + if (z >= 0x00004) { + n += 2; + z >>= 2; + } + if (z >= 0x00002) { + n += 1; /* >>= 1;*/ + } + return n; } -static int stbi__bitcount(unsigned int a) -{ - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) { + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) -{ - static unsigned int mul_table[9] = { +static int stbi__shiftsigned(unsigned int v, int shift, int bits) { + static unsigned int mul_table[9] = { 0, - 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, - 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0,0,1,0,2,4,6,0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8-bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; -} - -typedef struct -{ - int bpp, offset, hsz; - unsigned int mr,mg,mb,ma, all_a; - int extra_read; + 0xff /*0b11111111*/, + 0x55 /*0b01010101*/, + 0x49 /*0b01001001*/, + 0x11 /*0b00010001*/, + 0x21 /*0b00100001*/, + 0x41 /*0b01000001*/, + 0x81 /*0b10000001*/, + 0x01 /*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0, 0, 1, 0, 2, 4, 6, 0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8 - bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct { + int bpp, offset, hsz; + unsigned int mr, mg, mb, ma, all_a; + int extra_read; } stbi__bmp_data; -static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) -{ - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; - - if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); - - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - int i; - if (hsz != 108 && hsz != 124) +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) { + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') + return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) + return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) + return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) + return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) + return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha + // channel but it was all 0 + } else { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - stbi__get32le(s); // discard color space - for (i=0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); } - } - return (void *) 1; -} - - -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - unsigned int mr=0,mg=0,mb=0,ma=0, all_a; - stbi_uc pal[256][4]; - int psize=0,i,j,width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set - - flip_vertically = ((int) s->img_y) > 0; - s->img_y = abs((int) s->img_y); - - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; - - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - STBI_ASSERT(info.offset == s->callback_already_read + (int) (s->img_buffer - s->img_buffer_original)); - if (info.offset != s->callback_already_read + (s->img_buffer - s->buffer_start)) { - return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + stbi__get32le(s); // discard color space + for (i = 0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved } - } - - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert - - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); - - out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z=0; - if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } - for (i=0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) stbi__get8(s); - pal[i][3] = 255; + } + } + return (void *)1; +} + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; + stbi_uc pal[256][4]; + int psize = 0, i, j, width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int)s->img_y) > 0; + s->img_y = abs((int)s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + STBI_ASSERT(info.offset == + s->callback_already_read + + (int)(s->img_buffer - s->img_buffer_original)); + if (info.offset != + s->callback_already_read + (s->img_buffer - s->buffer_start)) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z = 0; + if (psize == 0 || psize > 256) { + STBI_FREE(out); + return stbi__errpuc("invalid", "Corrupt BMP"); + } + for (i = 0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) + stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - + psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) + width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) + width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) + width = s->img_x; + else { + STBI_FREE(out); + return stbi__errpuc("bad bpp", "Corrupt BMP"); + } + pad = (-width) & 3; + if (info.bpp == 1) { + for (j = 0; j < (int)s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i = 0; i < (int)s->img_x; ++i) { + int color = (v >> bit_offset) & 0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + if ((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) width = s->img_x; - else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } - pad = (-width)&3; - if (info.bpp == 1) { - for (j=0; j < (int) s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i=0; i < (int) s->img_x; ++i) { - int color = (v>>bit_offset)&0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - if((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j=0; j < (int) s->img_y; ++j) { - for (i=0; i < (int) s->img_x; i += 2) { - int v=stbi__get8(s),v2=0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - } - stbi__skip(s, pad); - } + } else { + for (j = 0; j < (int)s->img_y; ++j) { + for (i = 0; i < (int)s->img_x; i += 2) { + int v = stbi__get8(s), v2 = 0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + } + stbi__skip(s, pad); } - } else { - int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; - int z = 0; - int easy=0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) width = 3 * s->img_x; - else if (info.bpp == 16) width = 2*s->img_x; - else /* bpp = 32 and pad = 0 */ width=0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; + } + } else { + int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, + bcount = 0, acount = 0; + int z = 0; + int easy = 0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) + width = 3 * s->img_x; + else if (info.bpp == 16) + width = 2 * s->img_x; + else /* bpp = 32 and pad = 0 */ + width = 0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - if (!easy) { - if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr) - 7; + rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg) - 7; + gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb) - 7; + bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma) - 7; + acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - for (j=0; j < (int) s->img_y; ++j) { - if (easy) { - for (i=0; i < (int) s->img_x; ++i) { - unsigned char a; - out[z+2] = stbi__get8(s); - out[z+1] = stbi__get8(s); - out[z+0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i=0; i < (int) s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) out[z++] = STBI__BYTECAST(a); - } - } - stbi__skip(s, pad); + } + for (j = 0; j < (int)s->img_y; ++j) { + if (easy) { + for (i = 0; i < (int)s->img_x; ++i) { + unsigned char a; + out[z + 2] = stbi__get8(s); + out[z + 1] = stbi__get8(s); + out[z + 0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) + out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i = 0; i < (int)s->img_x; ++i) { + stbi__uint32 v = + (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) + out[z++] = STBI__BYTECAST(a); + } } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j=0; j < (int) s->img_y>>1; ++j) { - stbi_uc *p1 = out + j *s->img_x*target; - stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; - for (i=0; i < (int) s->img_x*target; ++i) { - t = p1[i]; p1[i] = p2[i]; p2[i] = t; - } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j = 0; j < (int)s->img_y >> 1; ++j) { + stbi_uc *p1 = out + j * s->img_x * target; + stbi_uc *p2 = out + (s->img_y - 1 - j) * s->img_x * target; + for (i = 0; i < (int)s->img_x * target; ++i) { + t = p1[i]; + p1[i] = p2[i]; + p2[i] = t; } - } + } + } - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; - return out; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + return out; } #endif @@ -5555,592 +6271,625 @@ static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) -{ - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) *is_rgb16 = 0; - switch(bits_per_pixel) { - case 8: return STBI_grey; - case 16: if(is_grey) return STBI_grey_alpha; - // fallthrough - case 15: if(is_rgb16) *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: return bits_per_pixel/8; - default: return 0; - } -} - -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) -{ - int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; - int sz, tga_colormap_type; - stbi__get8(s); // discard Offset - tga_colormap_type = stbi__get8(s); // colormap type - if( tga_colormap_type > 1 ) { - stbi__rewind(s); - return 0; // only RGB or indexed allowed - } - tga_image_type = stbi__get8(s); // image type - if ( tga_colormap_type == 1 ) { // colormapped (paletted) image - if (tga_image_type != 1 && tga_image_type != 9) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip image x and y origin - tga_colormap_bpp = sz; - } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { - stbi__rewind(s); - return 0; // only RGB or grey allowed, +/- RLE - } - stbi__skip(s,9); // skip colormap specification and image x/y origin - tga_colormap_bpp = 0; - } - tga_w = stbi__get16le(s); - if( tga_w < 1 ) { - stbi__rewind(s); - return 0; // test width +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int *is_rgb16) { + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) + *is_rgb16 = 0; + switch (bits_per_pixel) { + case 8: + return STBI_grey; + case 16: + if (is_grey) + return STBI_grey_alpha; + // fallthrough + case 15: + if (is_rgb16) + *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: + return bits_per_pixel / 8; + default: + return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) { + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, + tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if (tga_colormap_type > 1) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; } - tga_h = stbi__get16le(s); - if( tga_h < 1 ) { - stbi__rewind(s); - return 0; // test height + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__rewind(s); + return 0; } - tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits - if (tga_colormap_bpp != 0) { - if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { - // when using a colormap, tga_bits_per_pixel is the size of the indexes - // I don't think anything but 8 or 16bit indexes makes sense - stbi__rewind(s); - return 0; - } - tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); - } else { - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + stbi__skip(s, 4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ((tga_image_type != 2) && (tga_image_type != 3) && + (tga_image_type != 10) && (tga_image_type != 11)) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE } - if(!tga_comp) { + stbi__skip(s, 9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if (tga_w < 1) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if (tga_h < 1) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); return 0; } - if (x) *x = tga_w; - if (y) *y = tga_h; - if (comp) *comp = tga_comp; - return 1; // seems to have passed everything -} - -static int stbi__tga_test(stbi__context *s) -{ - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if ( tga_color_type == 1 ) { // colormapped (paletted) image - if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - stbi__skip(s,4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s,9); // skip colormap specification and image x/y origin - } - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp( + tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), + NULL); + } + if (!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) + *x = tga_w; + if (y) + *y = tga_h; + if (comp) + *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) { + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if (tga_color_type > 1) + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image + if (sz != 1 && sz != 9) + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin + } + if (stbi__get16le(s) < 1) + goto errorEnd; // test width + if (stbi__get16le(s) < 1) + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) + goto errorEnd; // for colormapped images, bpp is size of an index + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead + // of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) -{ - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255)/31); - out[1] = (stbi_uc)((g * 255)/31); - out[2] = (stbi_uc)((b * 255)/31); - - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. -} - -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16=0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char *tga_data; - unsigned char *tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO - - if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // do a tiny bit of precessing - if ( tga_image_type >= 8 ) - { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); - - // If I'm paletted, then I'll use the number of bits from the palette - if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - - if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - - // tga info - *x = tga_width; - *y = tga_height; - if (comp) *comp = tga_comp; - - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); - - tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); - - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset ); - - if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { - for (i=0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height -i - 1 : i; - stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if ( tga_indexed) - { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start ); - // load the palette - tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); - } - if (tga_rgb16) { - stbi_uc *pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i=0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc *out) { + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be + // swapped later + out[0] = (stbi_uc)((r * 255) / 31); + out[1] = (stbi_uc)((g * 255) / 31); + out[2] = (stbi_uc)((b * 255) / 31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16 = 0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused + // (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // do a tiny bit of precessing + if (tga_image_type >= 8) { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if (tga_indexed) + tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), + &tga_rgb16); + + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have + // ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) + *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = + (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) + return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset); + + if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { + for (i = 0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height - i - 1 : i; + stbi_uc *tga_row = tga_data + row * tga_width * tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if (tga_indexed) { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // load the data - for (i=0; i < tga_width * tga_height; ++i) - { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if ( tga_is_RLE ) - { - if ( RLE_count == 0 ) - { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if ( !RLE_repeating ) - { - read_next_pixel = 1; - } - } else - { - read_next_pixel = 1; - } - // OK, if I need to read a pixel, do it now - if ( read_next_pixel ) - { - // load however much data we did have - if ( tga_indexed ) - { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if ( pal_idx >= tga_palette_len ) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx+j]; - } - } else if(tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel - - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i*tga_comp+j] = raw_data[j]; - // in case we're in RLE mode, keep counting down - --RLE_count; + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start); + // load the palette + tga_palette = + (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); } - // do I need to invert the image? - if ( tga_inverted ) - { - for (j = 0; j*2 < tga_height; ++j) - { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) - { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } - } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i = 0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // clear my palette, if I had one - if ( tga_palette != NULL ) - { - STBI_FREE( tga_palette ); + } + // load the data + for (i = 0; i < tga_width * tga_height; ++i) { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if (tga_is_RLE) { + if (RLE_count == 0) { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if (!RLE_repeating) { + read_next_pixel = 1; + } + } else { + read_next_pixel = 1; } - } + // OK, if I need to read a pixel, do it now + if (read_next_pixel) { + // load however much data we did have + if (tga_indexed) { + // read in index, then perform the lookup + int pal_idx = + (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if (pal_idx >= tga_palette_len) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx + j]; + } + } else if (tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) - { - unsigned char* tga_pixel = tga_data; - for (i=0; i < tga_width * tga_height; ++i) - { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i * tga_comp + j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if (tga_inverted) { + for (j = 0; j * 2 < tga_height; ++j) { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } } - } + } + // clear my palette, if I had one + if (tga_palette != NULL) { + STBI_FREE(tga_palette); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) { + unsigned char *tga_pixel = tga_data; + for (i = 0; i < tga_width * tga_height; ++i) { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, + tga_height); - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = - tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = + tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; } #endif // ************************************************************************************************* -// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, +// tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s) -{ - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; -} - -static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) -{ - int count, nleft, len; - - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; - } +static int stbi__psd_test(stbi__context *s) { + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) { + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) + return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; } - } - - return 1; -} + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) + return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w, h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", + "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", + "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for + // other modes.) + stbi__skip(s, stbi__get32be(s)); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s)); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s)); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", + "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *)stbi__malloc(4 * w * h); + + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w * h; + + // Initialize the data to zero. + // memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes + // literally. Else if n is between -127 and -1 inclusive, copy the next + // byte -n+1 times. Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row + // in the data, which we're going to just skip. + stbi__skip(s, h * channelCount * 2); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out + channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w,h; - stbi_uc *out; - STBI_NOTUSED(ri); - - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); - - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - - // Skip 6 reserved bytes. - stbi__skip(s, 6 ); - - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); - - if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s,stbi__get32be(s) ); - - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s) ); - - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s) ); - - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); - - // Create the destination image. - - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *) stbi__malloc(4 * w*h); - - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w*h; - - // Initialize the data to zero. - //memset( out, 0, pixelCount * 4 ); - - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop - - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2 ); - - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc *p; - - p = out+channel; - if (channel >= channelCount) { - // Fill this channel with default data. + } else { + // We're at the raw image data. It's each channel in order (Red, Green, + // Blue, Alpha, ...) where each channel consists of an 8-bit (or 16-bit) + // value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out + channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16)stbi__get16be(s); + } else { + stbi_uc *p = out + channel; + if (bitdepth == 16) { // input bpc for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); - } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } - } + *p = (stbi_uc)(stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } } - - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc *p = out+channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } - } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16) stbi__get16be(s); - } else { - stbi_uc *p = out+channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc) (stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } - } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i = 0; i < w * h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *)out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); + pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); + pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); + } } - } - - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i=0; i < w*h; ++i) { - stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); - pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); - pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); - } - } - } else { - for (i=0; i < w*h; ++i) { - unsigned char *pixel = out + 4*i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); - pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); - pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); - } - } + } else { + for (i = 0; i < w * h; ++i) { + unsigned char *pixel = out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); + pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); + pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); + } } - } - - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - - if (comp) *comp = 4; - *y = h; - *x = w; - - return out; + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, + w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + if (comp) + *comp = 4; + *y = h; + *x = w; + + return out; } #endif @@ -6152,215 +6901,222 @@ static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context *s,const char *str) -{ - int i; - for (i=0; i<4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s, const char *str) { + int i; + for (i = 0; i < 4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context *s) -{ - int i; +static int stbi__pic_test_core(stbi__context *s) { + int i; - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) + return 0; - for(i=0;i<84;++i) - stbi__get8(s); + for (i = 0; i < 84; ++i) + stbi__get8(s); - if (!stbi__pic_is4(s,"PICT")) - return 0; + if (!stbi__pic_is4(s, "PICT")) + return 0; - return 1; + return 1; } -typedef struct -{ - stbi_uc size,type,channel; +typedef struct { + stbi_uc size, type, channel; } stbi__pic_packet; -static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) -{ - int mask=0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) { + int mask = 0x80, i; - for (i=0; i<4; ++i, mask>>=1) { - if (channel & mask) { - if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); - dest[i]=stbi__get8(s); - } - } + for (i = 0; i < 4; ++i, mask >>= 1) { + if (channel & mask) { + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "PIC file too short"); + dest[i] = stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) -{ - int mask=0x80,i; +static void stbi__copyval(int channel, stbi_uc *dest, const stbi_uc *src) { + int mask = 0x80, i; - for (i=0;i<4; ++i, mask>>=1) - if (channel&mask) - dest[i]=src[i]; + for (i = 0; i < 4; ++i, mask >>= 1) + if (channel & mask) + dest[i] = src[i]; } -static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) -{ - int act_comp=0,num_packets=0,y,chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s, int width, int height, + int *comp, stbi_uc *result) { + int act_comp = 0, num_packets = 0, y, chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data - // for the same channel in multiple packets. - do { - stbi__pic_packet *packet; + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. + do { + stbi__pic_packet *packet; - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return stbi__errpuc("bad format","too many packets"); + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return stbi__errpuc("bad format", "too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); - if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (reading packets)"); + if (packet->size != 8) + return stbi__errpuc("bad format", "packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for(y=0; ytype) { - default: - return stbi__errpuc("bad format","packet has bad compression type"); + switch (packet->type) { + default: + return stbi__errpuc("bad format", "packet has bad compression type"); - case 0: {//uncompressed - int x; + case 0: { // uncompressed + int x; - for(x=0;xchannel,dest)) - return 0; - break; - } + for (x = 0; x < width; ++x, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + break; + } - case 1://Pure RLE - { - int left=width, i; - - while (left>0) { - stbi_uc count,value[4]; - - count=stbi__get8(s); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - - if (count > left) - count = (stbi_uc) left; - - if (!stbi__readval(s,packet->channel,value)) return 0; - - for(i=0; ichannel,dest,value); - left -= count; - } - } - break; - - case 2: {//Mixed RLE - int left=width; - while (left>0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); - - if (count >= 128) { // Repeated - stbi_uc value[4]; - - if (count==128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file","scanline overrun"); - - if (!stbi__readval(s,packet->channel,value)) - return 0; - - for(i=0;ichannel,dest,value); - } else { // Raw - ++count; - if (count>left) return stbi__errpuc("bad file","scanline overrun"); - - for(i=0;ichannel,dest)) - return 0; - } - left-=count; - } - break; - } - } + case 1: // Pure RLE + { + int left = width, i; + + while (left > 0) { + stbi_uc count, value[4]; + + count = stbi__get8(s); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pure read count)"); + + if (count > left) + count = (stbi_uc)left; + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + left -= count; + } + } break; + + case 2: { // Mixed RLE + int left = width; + while (left > 0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", + "file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count == 128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + } else { // Raw + ++count; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + for (i = 0; i < count; ++i, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + } + left -= count; + } + break; } - } + } + } + } - return result; + return result; } -static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) -{ - stbi_uc *result; - int i, x,y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s, int *px, int *py, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *result; + int i, x, y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) comp = &internal_comp; + if (!comp) + comp = &internal_comp; - for (i=0; i<92; ++i) - stbi__get8(s); + for (i = 0; i < 92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) + return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); //skip `ratio' - stbi__get16be(s); //skip `fields' - stbi__get16be(s); //skip `pad' + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); - memset(result, 0xff, x*y*4); + // intermediate buffer is RGBA + result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); + memset(result, 0xff, x * y * 4); - if (!stbi__pic_load_core(s,x,y,comp, result)) { - STBI_FREE(result); - result=0; - } - *px = x; - *py = y; - if (req_comp == 0) req_comp = *comp; - result=stbi__convert_format(result,4,req_comp,x,y); + if (!stbi__pic_load_core(s, x, y, comp, result)) { + STBI_FREE(result); + result = 0; + } + *px = x; + *py = y; + if (req_comp == 0) + req_comp = *comp; + result = stbi__convert_format(result, 4, req_comp, x, y); - return result; + return result; } -static int stbi__pic_test(stbi__context *s) -{ - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) { + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6368,514 +7124,539 @@ static int stbi__pic_test(stbi__context *s) // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct -{ - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct { + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct -{ - int w,h; - stbi_uc *out; // output buffer (always 4 components) - stbi_uc *background; // The current "background" as far as a gif is concerned - stbi_uc *history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc *color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct { + int w, h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context *s) -{ - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') return 0; - if (stbi__get8(s) != 'a') return 0; - return 1; -} - -static int stbi__gif_test(stbi__context *s) -{ - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; -} - -static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) -{ - int i; - for (i=0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } -} +static int stbi__gif_test_raw(stbi__context *s) { + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') + return 0; + if (stbi__get8(s) != 'a') + return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) { + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], + int num_entries, int transp) { + int i; + for (i = 0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, + int is_info) { + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') + return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') + return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + + if (comp != 0) + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the + // comments + + if (is_info) + return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) { + stbi__gif *g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind(s); + return 0; + } + if (x) + *x = g->w; + if (y) + *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) { + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) + return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) { + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) + return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc)init_code; + g->codes[init_code].suffix = (stbi_uc)init_code; + } + + // support no starting clear code + avail = clear + 2; + oldcode = -1; + + len = 0; + for (;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32)stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s, len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } -static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) -{ - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - version = stbi__get8(s); - if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + p->prefix = (stbi__int16)oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; + stbi__out_gif_code(g, (stbi__uint16)code); - if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } - if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image +// doesn't support it two back is the image from two frames ago, used for a very +// specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, + int req_comp, stbi_uc *two_back) { + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour + // (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp, 0)) + return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *)stbi__malloc(4 * pcount); + g->background = (stbi_uc *)stbi__malloc(4 * pcount); + g->history = (stbi_uc *)stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); - if (is_info) return 1; + // image is treated as "transparent" at the start - ie, nothing overwrites + // the current background; background colour is only used for pixels that + // are not rendered first frame, after that "background" color refers to the + // color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, + 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, + pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the + // old background + } - if (g->flags & 0x80) - stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } - return 1; -} + // background is what out is after the undoing of the previou frame; + memcpy(g->background, g->out, 4 * g->w * g->h); + } + + // clear my history; + memset(g->history, 0x00, + g->w * g->h); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc *o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; -static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind( s ); - return 0; - } - if (x) *x = g->w; - if (y) *y = g->h; - STBI_FREE(g); - return 1; -} + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; -static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) -{ - stbi_uc *p, *c; - int idx; - - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); - - if (g->cur_y >= g->max_y) return; - - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; - - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; - - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + g->lflags = stbi__get8(s); - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; } - } -} -static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) -{ - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw *p; - - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc) init_code; - g->codes[init_code].suffix = (stbi_uc) init_code; - } - - // support no starting clear code - avail = clear+2; - oldcode = -1; - - len = 0; - for(;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; - } - --len; - bits |= (stbi__int32) stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s,len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), + g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *)g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *)g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } + o = stbi__process_gif_raster(s, g); + if (!o) + return NULL; - p->prefix = (stbi__int16) oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = + 255; // just in case it was made transparent, undo that; It will + // be reset next frame if need be; + memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); + } + } + } - stbi__out_gif_code(g, (stbi__uint16) code); + return o; + } - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = + 10 * stbi__get16le( + s); // delay - 1/100th of a second, saving as 1/1000ths. + + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } - } -} - -// this function is designed to support animated gifs, although stb_image doesn't support it -// two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) -{ - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); - - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *) stbi__malloc(4 * pcount); - g->background = (stbi_uc *) stbi__malloc(4 * pcount); - g->history = (stbi_uc *) stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); - - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; - - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); } + break; + } - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); - } - } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); - } - } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers - // background is what out is after the undoing of the previou frame; - memcpy( g->background, g->out, 4 * g->w * g->h ); - } - - // clear my history; - memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame - - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { - stbi__int32 x, y, w, h; - stbi_uc *o; - - x = stbi__get16le(s); - y = stbi__get16le(s); - w = stbi__get16le(s); - h = stbi__get16le(s); - if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); - - g->line_size = g->w * 4; - g->start_x = x * 4; - g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; - - // if the width of the specified rectangle is 0, that means - // we may not see *any* pixels or the image is malformed; - // to make sure this is caught, move the current y down to - // max_y (which is what out_gif_code checks). - if (w == 0) - g->cur_y = g->max_y; - - g->lflags = stbi__get8(s); - - if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; - } else { - g->step = g->line_size; - g->parse = 0; - } + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp) { + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } - if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *) g->lpal; - } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *) g->pal; - } else - return stbi__errpuc("missing color table", "Corrupt GIF"); - - o = stbi__process_gif_raster(s, g); - if (!o) return NULL; - - // if this was the first frame, - pcount = g->w * g->h; - if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); - } - } - } + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = + (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + if (NULL == tmp) { + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + return stbi__errpuc("outofmem", "Out of memory"); + } else { + out = (stbi_uc *)tmp; + out_size = layers * stride; + } + + if (delays) { + *delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, + sizeof(int) * layers); + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc *)stbi__malloc(layers * stride); + out_size = layers * stride; + if (delays) { + *delays = (int *)stbi__malloc(layers * sizeof(int)); + delays_size = layers * sizeof(int); + } + } + memcpy(out + ((layers - 1) * stride), u, stride); + if (layers >= 2) { + two_back = out - 2 * stride; + } - return o; - } - - case 0x21: // Comment Extension. - { - int len; - int ext = stbi__get8(s); - if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } - } - while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); - } - break; - } + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); - case 0x3B: // gif stream termination code - return (stbi_uc *) s; // using '1' causes warning on some compilers + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - default: - return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } -} + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc *u = 0; - stbi_uc *out = 0; - stbi_uc *two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; - - if (out) { - void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); - if (NULL == tmp) { - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); - return stbi__errpuc("outofmem", "Out of memory"); - } - else { - out = (stbi_uc*) tmp; - out_size = layers * stride; - } - - if (delays) { - *delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc*)stbi__malloc( layers * stride ); - out_size = layers * stride; - if (delays) { - *delays = (int*) stbi__malloc( layers * sizeof(int) ); - delays_size = layers * sizeof(int); - } - } - memcpy( out + ((layers - 1) * stride), u, stride ); - if (layers >= 2) { - two_back = out - 2 * stride; - } +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); - if (delays) { - (*delays)[layers - 1U] = g.delay; - } - } - } while (u != 0); + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + return u; } -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); - - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } - - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); - - return u; -} - -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) -{ - return stbi__gif_info_raw(s,x,y,comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) { + return stbi__gif_info_raw(s, x, y, comp); } #endif @@ -6883,396 +7664,434 @@ static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context *s, const char *signature) -{ - int i; - for (i=0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; -} - -static int stbi__hdr_test(stbi__context* s) -{ - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if(!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; -} - -#define STBI__HDR_BUFLEN 1024 -static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) -{ - int len=0; - char c = '\0'; - - c = (char) stbi__get8(z); - - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN-1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) { + int i; + for (i = 0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context *s) { + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if (!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) { + int len = 0; + char c = '\0'; + + c = (char)stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN - 1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char)stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) { + if (input[3] != 0) { + float f1; + // Exponent + f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) + output[1] = 1; + if (req_comp == 4) + output[3] = 1; + } else { + switch (req_comp) { + case 4: + output[3] = 1; /* fallthrough */ + case 3: + output[0] = output[1] = output[2] = 0; + break; + case 2: + output[1] = 1; /* fallthrough */ + case 1: + output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1, c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s, buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && + strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) + return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int)strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) + *comp = 3; + if (req_comp == 0) + req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = + (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if (width < 8 || width >= 32768) { + // Read flat data + for (j = 0; j < height; ++j) { + for (i = 0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, + req_comp); } - c = (char) stbi__get8(z); - } - - buffer[len] = 0; - return buffer; -} + } + } else { + // Read RLE-encoded data + scanline = NULL; -static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) -{ - if ( input[3] != 0 ) { - float f1; - // Exponent - f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a + // decoded pixel (note this can't be a valid pixel--one of RGB must be + // >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc)c1; + rgbe[1] = (stbi_uc)c2; + rgbe[2] = (stbi_uc)len; + rgbe[3] = (stbi_uc)stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense } - if (req_comp == 2) output[1] = 1; - if (req_comp == 4) output[3] = 1; - } else { - switch (req_comp) { - case 4: output[3] = 1; /* fallthrough */ - case 3: output[0] = output[1] = output[2] = 0; - break; - case 2: output[1] = 1; /* fallthrough */ - case 1: output[0] = 0; - break; + len <<= 8; + len |= stbi__get8(s); + if (len != width) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - } -} - -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int width, height; - stbi_uc *scanline; - float *hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1,c2, z; - const char *headerToken; - STBI_NOTUSED(ri); - - // Check identifier - headerToken = stbi__hdr_gettoken(s,buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); - - // Parse header - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); - - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int) strtol(token, NULL, 10); - - if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - - *x = width; - *y = height; - - if (comp) *comp = 3; - if (req_comp == 0) req_comp = 3; - - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); - - // Read data - hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); - - // Load image data - // image data is stored as some number of sca - if ( width < 8 || width >= 32768) { - // Read flat data - for (j=0; j < height; ++j) { - for (i=0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } + if (scanline == NULL) { + scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } } - } else { - // Read RLE-encoded data - scanline = NULL; - - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc) c1; - rgbe[1] = (stbi_uc) c2; - rgbe[2] = (stbi_uc) len; - rgbe[3] = (stbi_uc) stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - if (scanline == NULL) { - scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - for (i=0; i < width; ++i) - stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } } - if (scanline) - STBI_FREE(scanline); - } - - return hdr_data; -} - -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int dummy; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (stbi__hdr_test(s) == 0) { - stbi__rewind( s ); - return 0; - } - - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) { - stbi__rewind( s ); - return 0; - } - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *y = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *x = (int) strtol(token, NULL, 10); - *comp = 3; - return 1; + for (i = 0; i < width; ++i) + stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, + scanline + i * 4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind(s); + return 0; + } + + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) { + stbi__rewind(s); + return 0; + } + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *y = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *x = (int)strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) -{ - void *p; - stbi__bmp_data info; - - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - stbi__rewind( s ); - if (p == NULL) - return 0; - if (x) *x = s->img_x; - if (y) *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) { + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + stbi__rewind(s); + if (p == NULL) + return 0; + if (x) + *x = s->img_x; + if (y) + *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) -{ - int channelCount, dummy, depth; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind( s ); - return 0; - } - *comp = 4; - return 1; -} - -static int stbi__psd_is16(stbi__context *s) -{ - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - (void) stbi__get32be(s); - (void) stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) { + int channelCount, dummy, depth; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind(s); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) { + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + (void)stbi__get32be(s); + (void)stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind(s); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) -{ - int act_comp=0,num_packets=0,chained,dummy; - stbi__pic_packet packets[10]; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { - stbi__rewind(s); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) { + int act_comp = 0, num_packets = 0, chained, dummy; + stbi__pic_packet packets[10]; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) return 0; - } - stbi__skip(s, 88); + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind( s); + if (stbi__at_eof(s)) { + stbi__rewind(s); return 0; - } - if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind( s ); + } + if (packet->size != 8) { + stbi__rewind(s); return 0; - } - - stbi__skip(s, 8); - - do { - stbi__pic_packet *packet; - - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return 0; - - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; - - if (stbi__at_eof(s)) { - stbi__rewind( s ); - return 0; - } - if (packet->size != 8) { - stbi__rewind( s ); - return 0; - } - } while (chained); + } + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); + *comp = (act_comp & 0x10 ? 4 : 3); - return 1; + return 1; } #endif @@ -7290,257 +8109,266 @@ static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s) -{ - char p, t; - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) { + char p, t; + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + return 1; } -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - STBI_NOTUSED(ri); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + STBI_NOTUSED(ri); - if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) - return 0; + if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) + return 0; - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; - if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "PNM too large"); + if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "PNM too large"); - out = (stbi_uc *) stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - stbi__getn(s, out, s->img_n * s->img_x * s->img_y); + out = (stbi_uc *)stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + stbi__getn(s, out, s->img_n * s->img_x * s->img_y); - if (req_comp && req_comp != s->img_n) { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - return out; + if (req_comp && req_comp != s->img_n) { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + return out; } -static int stbi__pnm_isspace(char c) -{ - return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +static int stbi__pnm_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; } -static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) -{ - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char) stbi__get8(s); +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) { + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char)stbi__get8(s); - if (stbi__at_eof(s) || *c != '#') - break; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') + *c = (char)stbi__get8(s); + } } -static int stbi__pnm_isdigit(char c) -{ - return c >= '0' && c <= '9'; +static int stbi__pnm_isdigit(char c) { + return c >= '0' && c <= '9'; } -static int stbi__pnm_getinteger(stbi__context *s, char *c) -{ - int value = 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) { + int value = 0; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value*10 + (*c - '0'); - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value * 10 + (*c - '0'); + *c = (char)stbi__get8(s); + } - return value; + return value; } -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) -{ - int maxv, dummy; - char c, p, t; +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) { + int maxv, dummy; + char c, p, t; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; - stbi__rewind(s); + stbi__rewind(s); - // Get identifier - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } + // Get identifier + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = + (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - c = (char) stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); + c = (char)stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width - stbi__pnm_skip_whitespace(s, &c); + *x = stbi__pnm_getinteger(s, &c); // read width + stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height - stbi__pnm_skip_whitespace(s, &c); + *y = stbi__pnm_getinteger(s, &c); // read height + stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value + maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 255) - return stbi__err("max value > 255", "PPM image not 8-bit"); - else - return 1; + if (maxv > 255) + return stbi__err("max value > 255", "PPM image not 8-bit"); + else + return 1; } #endif -static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) -{ - #ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) return 1; - #endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) { +#ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) + return 1; +#endif - // test tga last because it's a crappy test! - #ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; - #endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +// test tga last because it's a crappy test! +#ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; +#endif + return stbi__err("unknown image type", + "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context *s) -{ - #ifndef STBI_NO_PNG - if (stbi__png_is16(s)) return 1; - #endif +static int stbi__is_16_main(stbi__context *s) { +#ifndef STBI_NO_PNG + if (stbi__png_is16(s)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) + return 1; +#endif - return 0; + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_info_from_file(f, x, y, comp); - fclose(f); - return result; -} - -STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s,x,y,comp); - fseek(f,pos,SEEK_SET); - return r; -} - -STBIDEF int stbi_is_16_bit(char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_is_16_bit_from_file(f); - fclose(f); - return result; -} - -STBIDEF int stbi_is_16_bit_from_file(FILE *f) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f,pos,SEEK_SET); - return r; +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s, x, y, comp); + fseek(f, pos, SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE *f) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f, pos, SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, + int *x, int *y, int *comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, + void *user) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION /* revision history: - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs - 2.19 (2018-02-11) fix warning - 2.18 (2018-01-30) fix warnings - 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and + platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix + warnings 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug 1-bit BMP *_is_16_bit api avoid warnings @@ -7555,13 +8383,11 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user warning fixes; disable run-time SSE detection on gcc; uniform handling of optional "return" values; thread-safe initialization of zlib tables - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) allocate large structures on the stack - remove white matting for transparent PSD - fix reported channel count for PNG & BMP - re-enable SSE2 in non-gcc 64-bit + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet + JPGs 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now 2.12 + (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 (2016-04-02) + allocate large structures on the stack remove white matting for transparent + PSD fix reported channel count for PNG & BMP re-enable SSE2 in non-gcc 64-bit support RGB-formatted JPEG read 16-bit PNGs (only as 8-bit) 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED @@ -7569,11 +8395,9 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 16-bit-per-pixel TGA (not bit-per-component) info() for TGA could break due to .hdr handling info() for BMP to shares code instead of sloppy parse - can use STBI_REALLOC_SIZED if allocator doesn't support realloc - code cleanup - 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA - 2.07 (2015-09-13) fix compiler warnings - partial animated GIF support + can use STBI_REALLOC_SIZED if allocator doesn't support + realloc code cleanup 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD + as RGBA 2.07 (2015-09-13) fix compiler warnings partial animated GIF support limited 16-bpc PSD support #ifdef unused functions bug with < 92 byte PIC,PNM,HDR,TGA @@ -7584,23 +8408,18 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user stbi_set_flip_vertically_on_load (nguillemot) fix NEON support; fix mingw support 2.02 (2015-01-19) fix incorrect assert, fix warning - 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 - 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG - 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) - progressive JPEG (stb) - PGM/PPM support (Ken Miller) - STBI_MALLOC,STBI_REALLOC,STBI_FREE + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit + without -msse2 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG 2.00 + (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) progressive + JPEG (stb) PGM/PPM support (Ken Miller) STBI_MALLOC,STBI_REALLOC,STBI_FREE GIF bugfix -- seemingly never worked STBI_NO_*, STBI_ONLY_* 1.48 (2014-12-14) fix incorrectly-named assert() - 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) - optimize PNG (ryg) - fix bug in interlaced PNG with user-specified channel count (stb) - 1.46 (2014-08-26) - fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG - 1.45 (2014-08-16) - fix MSVC-ARM internal compiler error by wrapping malloc - 1.44 (2014-08-07) + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar + Cornut & stb) optimize PNG (ryg) fix bug in interlaced PNG with + user-specified channel count (stb) 1.46 (2014-08-26) fix broken tRNS chunk + (colorkey-style transparency) in non-paletted PNG 1.45 (2014-08-16) fix + MSVC-ARM internal compiler error by wrapping malloc 1.44 (2014-08-07) various warning fixes from Ronny Chevalier 1.43 (2014-07-15) fix MSVC-only compiler problem in code changed in 1.42 @@ -7609,73 +8428,48 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user fixes to stbi__cleanup_jpeg path added STBI_ASSERT to avoid requiring assert.h 1.41 (2014-06-25) - fix search&replace from 1.36 that messed up comments/error messages - 1.40 (2014-06-22) - fix gcc struct-initialization warning - 1.39 (2014-06-15) - fix to TGA optimization when req_comp != number of components in TGA; - fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) - add support for BMP version 5 (more ignored fields) - 1.38 (2014-06-06) - suppress MSVC warnings on integer casts truncating values - fix accidental rename of 'skip' field of I/O - 1.37 (2014-06-04) - remove duplicate typedef - 1.36 (2014-06-03) - convert to header file single-file library - if de-iphone isn't set, load iphone images color-swapped instead of returning NULL - 1.35 (2014-05-27) - various warnings - fix broken STBI_SIMD path - fix bug where stbi_load_from_file no longer left file pointer in correct place - fix broken non-easy path for 32-bit BMP (possibly never used) - TGA optimization by Arseny Kapoulkine - 1.34 (unknown) - use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case - 1.33 (2011-07-14) - make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements - 1.32 (2011-07-13) - support for "info" function for all supported filetypes (SpartanJ) - 1.31 (2011-06-20) - a few more leak fixes, bug in PNG handling (SpartanJ) - 1.30 (2011-06-11) - added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + fix search&replace from 1.36 that messed up comments/error + messages 1.40 (2014-06-22) fix gcc struct-initialization warning 1.39 + (2014-06-15) fix to TGA optimization when req_comp != number of components in + TGA; fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my + test suite) add support for BMP version 5 (more ignored fields) 1.38 + (2014-06-06) suppress MSVC warnings on integer casts truncating values fix + accidental rename of 'skip' field of I/O 1.37 (2014-06-04) remove duplicate + typedef 1.36 (2014-06-03) convert to header file single-file library if + de-iphone isn't set, load iphone images color-swapped instead of returning + NULL 1.35 (2014-05-27) various warnings fix broken STBI_SIMD path fix bug + where stbi_load_from_file no longer left file pointer in correct place fix + broken non-easy path for 32-bit BMP (possibly never used) TGA optimization by + Arseny Kapoulkine 1.34 (unknown) use STBI_NOTUSED in + stbi__resample_row_generic(), fix one more leak in tga failure case 1.33 + (2011-07-14) make stbi_is_hdr work in STBI_NO_HDR (as specified), minor + compiler-friendly improvements 1.32 (2011-07-13) support for "info" function + for all supported filetypes (SpartanJ) 1.31 (2011-06-20) a few more leak + fixes, bug in PNG handling (SpartanJ) 1.30 (2011-06-11) added ability to + load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway - error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) - fix inefficiency in decoding 32-bit BMP (David Woo) - 1.29 (2010-08-16) - various warning fixes from Aurelien Pocheville - 1.28 (2010-08-01) - fix bug in GIF palette transparency (SpartanJ) - 1.27 (2010-08-01) - cast-to-stbi_uc to fix warnings - 1.26 (2010-07-24) - fix bug in file buffering for PNG reported by SpartanJ - 1.25 (2010-07-17) - refix trans_data warning (Won Chun) - 1.24 (2010-07-12) - perf improvements reading from files on platforms with lock-heavy fgetc() - minor perf improvements for jpeg - deprecated type-specific functions so we'll get feedback if they're needed - attempt to fix trans_data warning (Won Chun) - 1.23 fixed bug in iPhone support - 1.22 (2010-07-10) - removed image *writing* support - stbi_info support from Jetro Lauha - GIF support from Jean-Marc Lienher + removed support for installable file formats (stbi_loader) -- + would have been broken for IO callbacks anyway error cases in bmp and tga + give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in + decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from + Aurelien Pocheville 1.28 (2010-08-01) fix bug in GIF palette transparency + (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings 1.26 + (2010-07-24) fix bug in file buffering for PNG reported by SpartanJ 1.25 + (2010-07-17) refix trans_data warning (Won Chun) 1.24 (2010-07-12) perf + improvements reading from files on platforms with lock-heavy fgetc() minor + perf improvements for jpeg deprecated type-specific functions so we'll get + feedback if they're needed attempt to fix trans_data warning (Won Chun) 1.23 + fixed bug in iPhone support 1.22 (2010-07-10) removed image *writing* + support stbi_info support from Jetro Lauha GIF support from Jean-Marc Lienher iPhone PNG-extensions from James Brown - warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) - 1.21 fix use of 'stbi_uc' in header (reported by jon blow) - 1.20 added support for Softimage PIC, by Tom Seddon - 1.19 bug in interlaced PNG corruption check (found by ryg) - 1.18 (2008-08-02) - fix a threading bug (local mutable static) - 1.17 support interlaced PNG - 1.16 major bugfix - stbi__convert_format converted one too many pixels - 1.15 initialize some fields for thread safety - 1.14 fix threadsafe conversion bug - header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. + Janez (U+017D)emva) 1.21 fix use of 'stbi_uc' in header (reported by jon + blow) 1.20 added support for Softimage PIC, by Tom Seddon 1.19 bug in + interlaced PNG corruption check (found by ryg) 1.18 (2008-08-02) fix a + threading bug (local mutable static) 1.17 support interlaced PNG 1.16 + major bugfix - stbi__convert_format converted one too many pixels 1.15 + initialize some fields for thread safety 1.14 fix threadsafe conversion + bug header-file-only version (#define STBI_HEADER_FILE_ONLY before including) 1.13 threadsafe 1.12 const qualifiers in the API 1.11 Support installable IDCT, colorspace conversion routines @@ -7685,15 +8479,14 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz 1.07 attempt to fix C++ warning/errors again 1.06 attempt to fix C++ warning/errors again - 1.05 fix TGA loading to return correct *comp and use good luminance calc - 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free - 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR - 1.02 support for (subset of) HDR files, float interface for preferred access to them - 1.01 fix bug: possible bug in handling right-side up bmps... not sure - fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all - 1.00 interface to zlib that skips zlib header - 0.99 correct handling of alpha in palette - 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 1.05 fix TGA loading to return correct *comp and use good luminance + calc 1.04 default float alpha is 1, not 255; use 'void *' for + stbi_image_free 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR 1.02 support + for (subset of) HDR files, float interface for preferred access to them 1.01 + fix bug: possible bug in handling right-side up bmps... not sure fix bug: the + stbi__bmp_load() and stbi__tga_load() functions didn't work at all 1.00 + interface to zlib that skips zlib header 0.99 correct handling of alpha in + palette 0.98 TGA loader by lonesock; dynamically add loaders (untested) 0.97 jpeg errors on too large a file; also catch another malloc failure 0.96 fix detection of invalid v value - particleman@mollyrocket forum 0.95 during header scan, seek to markers in case of padding @@ -7706,8 +8499,8 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 0.60 fix compiling as c++ 0.59 fix warnings: merge Dave Moore's -Wall fixes 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian - 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available - 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but + less than 16 available 0.56 fix bug: zlib uncompressed mode len vs. nlen 0.55 fix bug: restart_interval not initialized to 0 0.54 allow NULL for 'int *comp' 0.53 fix bug in png 3->4; speedup png decoding @@ -7718,7 +8511,6 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user first released version */ - /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. @@ -7760,4 +8552,3 @@ ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ------------------------------------------------------------------------------ */ - diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/compile.sh b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/compile.sh index cad138fb..16209b85 100755 --- a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/compile.sh @@ -4,7 +4,7 @@ python3 lenet-mnist.py printf "\n[Compile Script]: Convert TF model to LLVM IR\n" python3 -m tf2onnx.convert --saved-model lenet-mnist.tf --output model.onnx python3 ../../../../tools/ExtendONNXModel.py --model_path ./model.onnx --output_model_path ./extendedmodel.onnx > expected_op_seq.txt -onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp +onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp mlir-translate -mlir-to-llvmir extendedmodel.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/image.c b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/image.c index aac7662f..9f1c5292 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/image.c +++ b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/image.c @@ -157,7 +157,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/lenet-mnist.py b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/lenet-mnist.py index f0ec3859..ca1da259 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/lenet-mnist.py +++ b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/lenet-mnist.py @@ -1,9 +1,11 @@ from tensorflow.keras import datasets, layers, models, losses + def process_images(images): - images = images.reshape((-1, 28, 28, 1)) - images = images / 255.0 - return images + images = images.reshape((-1, 28, 28, 1)) + images = images / 255.0 + return images + # from src import tensorfi2 as tfi (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data() @@ -11,21 +13,31 @@ def process_images(images): test_images = process_images(test_images) model = models.Sequential() -model.add(layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1))) +model.add( + layers.Conv2D( + filters=6, kernel_size=(5, 5), activation="relu", input_shape=(28, 28, 1) + ) +) model.add(layers.AveragePooling2D()) -model.add(layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu')) +model.add(layers.Conv2D(filters=16, kernel_size=(5, 5), activation="relu")) model.add(layers.AveragePooling2D()) model.add(layers.Flatten()) -model.add(layers.Dense(units=120, activation='relu')) -model.add(layers.Dense(units=84, activation='relu')) -model.add(layers.Dense(units=10, activation = 'softmax')) - -model.compile(optimizer='adam', - loss=losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=['accuracy']) +model.add(layers.Dense(units=120, activation="relu")) +model.add(layers.Dense(units=84, activation="relu")) +model.add(layers.Dense(units=10, activation="softmax")) -model.fit(train_images, train_labels, batch_size=100, epochs=10, - validation_data=(test_images, test_labels)) +model.compile( + optimizer="adam", + loss=losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=["accuracy"], +) -model.save('./lenet-mnist.tf') +model.fit( + train_images, + train_labels, + batch_size=100, + epochs=10, + validation_data=(test_images, test_labels), +) +model.save("./lenet-mnist.tf") diff --git a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/stb_image.h b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/stb_image.h index accef483..5b891039 100644 --- a/sample_programs/ml_sample_programs/vision_models/lenet-mnist/stb_image.h +++ b/sample_programs/ml_sample_programs/vision_models/lenet-mnist/stb_image.h @@ -3,7 +3,8 @@ Do this: #define STB_IMAGE_IMPLEMENTATION - before you include this file in *one* C or C++ file to create the implementation. + before you include this file in *one* C or C++ file to create the +implementation. // i.e. it should look like this: #include ... @@ -13,15 +14,16 @@ #include "stb_image.h" You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. - And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using +malloc,realloc,free QUICK NOTES: Primarily of interest to game developers and other people who can avoid problematic images and only need the trivial interface - JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) - PNG 1/2/4/8/16-bit-per-channel + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as +stock IJG lib) PNG 1/2/4/8/16-bit-per-channel TGA (not sure what subset, if a subset) BMP non-1bpp, non-RLE @@ -50,25 +52,22 @@ RECENT REVISION HISTORY: 2.26 (2020-07-13) many minor fixes 2.25 (2020-02-02) fix warnings - 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically - 2.23 (2019-08-11) fix clang static analysis warning - 2.22 (2019-03-04) gif fixes, fix warnings - 2.21 (2019-02-25) fix typo in comment - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and +flip_vertically 2.23 (2019-08-11) fix clang static analysis warning 2.22 +(2019-03-04) gif fixes, fix warnings 2.21 (2019-02-25) fix typo in comment 2.20 +(2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix warnings 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings - 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes - 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 - RGB-format JPEG; remove white matting in PSD; - allocate large structures on the stack; - correct channel count for PNG & BMP - 2.10 (2016-01-22) avoid warning introduced in 2.09 - 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; +bugfixes 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE +detection on GCC 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for +Imagenet JPGs 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; +fixes 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 +(2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 RGB-format JPEG; remove +white matting in PSD; allocate large structures on the stack; correct channel +count for PNG & BMP 2.10 (2016-01-22) avoid warning introduced in 2.09 2.09 +(2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED See end of file for full revision history. @@ -86,38 +85,37 @@ RECENT REVISION HISTORY: github:urraka (animated gif) Junggon Kim (PNM comments) Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) socks-the-fox (16-bit PNG) - Jeremy Sawicki (handle all ImageNet JPGs) - Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Jeremy Sawicki (handle all ImageNet +JPGs) Optimizations & bugfixes Mikhail Morozov (1-bit BMP) Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) Arseny Kapoulkine John-Mark Allen Carmelo J Fdez-Aguera Bug & warning fixes - Marc LeBlanc David Woo Guillaume George Martins Mozeiko - Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski - Phil Jordan Dave Moore Roy Eltham - Hayaki Saito Nathan Reed Won Chun - Luke Graham Johan Duparc Nick Verigakis the Horde3D community - Thomas Ruf Ronny Chevalier github:rlyeh - Janez Zemva John Bartholomew Michal Cichon github:romigrou - Jonathan Blow Ken Hamada Tero Hanninen github:svdijk - Laurent Gomila Cort Stratton github:snagar - Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex - Cass Everitt Ryamond Barbiero github:grim210 - Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw - Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus - Josh Tobin Matthew Gregan github:poppolopoppo - Julian Raschke Gregory Mullen Christian Floisand github:darealshinji - Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 - Brad Weinberger Matvey Cherevko [reserved] - Luca Sas Alexander Veselov Zack Middleton [reserved] + Marc LeBlanc David Woo Guillaume George Martins +Mozeiko Christpher Lloyd Jerry Jansson Joseph Thomson Blazej +Dariusz Roszkowski Phil Jordan Dave Moore Roy +Eltham Hayaki Saito Nathan Reed Won Chun Luke Graham Johan +Duparc Nick Verigakis the Horde3D community Thomas Ruf Ronny +Chevalier github:rlyeh Janez Zemva John +Bartholomew Michal Cichon github:romigrou Jonathan Blow Ken +Hamada Tero Hanninen github:svdijk Laurent Gomila Cort +Stratton github:snagar Aruelien Pocheville Sergio Gonzalez Thibault +Reuille github:Zelex Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Matthew Gregan +github:poppolopoppo Julian Raschke Gregory Mullen Christian +Floisand github:darealshinji Baldur Karlsson Kevin Schmidt JR +Smith github:Michaelangel007 Brad Weinberger Matvey Cherevko +[reserved] Luca Sas Alexander Veselov Zack Middleton [reserved] Ryan C. Gordon [reserved] [reserved] DO NOT ADD YOUR NAME HERE - To add your name to the credits, pick a random blank space in the middle and fill it. - 80% of merge conflicts on stb PRs are due to people adding their name at the end - of the credits. + To add your name to the credits, pick a random blank space in the middle and +fill it. 80% of merge conflicts on stb PRs are due to people adding their name +at the end of the credits. */ #ifndef STBI_INCLUDE_STB_IMAGE_H @@ -136,14 +134,15 @@ RECENT REVISION HISTORY: // // ... process data if not NULL ... // // ... x = width, y = height, n = # 8-bit components per pixel ... // // ... replace '0' with '1'..'4' to force that many components per pixel -// // ... but 'n' will always be the number that it would have been if you said 0 -// stbi_image_free(data) +// // ... but 'n' will always be the number that it would have been if you +// said 0 stbi_image_free(data) // // Standard parameters: // int *x -- outputs image width in pixels // int *y -- outputs image height in pixels // int *channels_in_file -- outputs # of image components in image file -// int desired_channels -- if non-zero, # of image components requested in result +// int desired_channels -- if non-zero, # of image components requested in +// result // // The return value from an image loader is an 'unsigned char *' which points // to the pixel data, or NULL on an allocation failure or if the image is @@ -171,8 +170,8 @@ RECENT REVISION HISTORY: // and *x, *y, *channels_in_file will be unchanged. The function // stbi_failure_reason() can be queried for an extremely brief, end-user // unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS -// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly -// more user-friendly ones. +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get +// slightly more user-friendly ones. // // Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. // @@ -196,11 +195,12 @@ RECENT REVISION HISTORY: // 2. easy to maintain // 3. good performance // -// Sometimes I let "good performance" creep up in priority over "easy to maintain", -// and for best performance I may provide less-easy-to-use APIs that give higher -// performance, in addition to the easy-to-use ones. Nevertheless, it's important -// to keep in mind that from the standpoint of you, a client of this library, -// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// Sometimes I let "good performance" creep up in priority over "easy to +// maintain", and for best performance I may provide less-easy-to-use APIs that +// give higher performance, in addition to the easy-to-use ones. Nevertheless, +// it's important to keep in mind that from the standpoint of you, a client of +// this library, all you care about is #1 and #3, and stb libraries DO NOT +// emphasize #3 above all. // // Some secondary priorities arise directly from the first two, some of which // provide more explicit reasons why performance can't be emphasized. @@ -219,7 +219,8 @@ RECENT REVISION HISTORY: // overhead. // // The three functions you must define are "read" (reads some bytes of data), -// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the +// end). // // =========================================================================== // @@ -247,10 +248,11 @@ RECENT REVISION HISTORY: // HDR image support (disable by defining STBI_NO_HDR) // // stb_image supports loading HDR images in general, and currently the Radiance -// .HDR file format specifically. You can still load any file through the existing -// interface; if you attempt to load an HDR file, it will be automatically remapped -// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; -// both of these constants can be reconfigured through this interface: +// .HDR file format specifically. You can still load any file through the +// existing interface; if you attempt to load an HDR file, it will be +// automatically remapped to LDR, assuming gamma 2.2 and an arbitrary scale +// factor defaulting to 1; both of these constants can be reconfigured through +// this interface: // // stbi_hdr_to_ldr_gamma(2.2f); // stbi_hdr_to_ldr_scale(1.0f); @@ -342,14 +344,13 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum -{ - STBI_default = 0, // only used for desired_channels +enum { + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -377,11 +378,13 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct -{ - int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof) (void *user); // returns nonzero if we are at end of file/data +typedef struct { + int (*read)(void *user, char *data, + int size); // fill 'data' with 'size' bytes. return number of + // bytes actually read + void (*skip)(void *user, int n); // skip the next 'n' bytes, or 'unget' the + // last -n bytes if negative + int (*eof)(void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -389,21 +392,33 @@ typedef struct // 8-bits-per-channel interface // -STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); -// for stbi_load_from_file, file pointer is left pointing immediately after image +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after +// image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input); #endif //////////////////////////////////// @@ -411,12 +426,20 @@ STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wch // 16-bits-per-channel interface // -STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); #endif //////////////////////////////////// @@ -424,83 +447,102 @@ STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_i // float-per-channel interface // #ifndef STBI_NO_LINEAR - STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); - #ifndef STBI_NO_STDIO - STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); - #endif +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +#endif #endif #ifndef STBI_NO_HDR - STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); - STBIDEF void stbi_hdr_to_ldr_scale(float scale); +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); +STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR - STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); - STBIDEF void stbi_ldr_to_hdr_scale(float scale); +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); +STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename); -STBIDEF int stbi_is_hdr_from_file(FILE *f); +STBIDEF int stbi_is_hdr(char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO - // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char *stbi_failure_reason (void); +STBIDEF const char *stbi_failure_reason(void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free (void *retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, + int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, + void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit (char const *filename); -STBIDEF int stbi_is_16_bit_from_file(FILE *f); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit(char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif - - // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); // indicate whether we should process iphone images back to canonical format, // or just pass them through "as-is" STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); -// flip the image vertically, so the first pixel in the output array is the bottom left +// flip the image vertically, so the first pixel in the output array is the +// bottom left STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); -// as above, but only applies to images loaded on the thread that calls the function -// this function is only available if your compiler supports thread-local variables; -// calling it will fail to link if your compiler doesn't -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); +// as above, but only applies to images loaded on the thread that calls the +// function this function is only available if your compiler supports +// thread-local variables; calling it will fail to link if your compiler doesn't +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); // ZLIB client - used by PNG, available for other purposes -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header); STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); - -STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, + int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); #ifdef __cplusplus } @@ -513,52 +555,53 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ - || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ - || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ - || defined(STBI_ONLY_ZLIB) - #ifndef STBI_ONLY_JPEG - #define STBI_NO_JPEG - #endif - #ifndef STBI_ONLY_PNG - #define STBI_NO_PNG - #endif - #ifndef STBI_ONLY_BMP - #define STBI_NO_BMP - #endif - #ifndef STBI_ONLY_PSD - #define STBI_NO_PSD - #endif - #ifndef STBI_ONLY_TGA - #define STBI_NO_TGA - #endif - #ifndef STBI_ONLY_GIF - #define STBI_NO_GIF - #endif - #ifndef STBI_ONLY_HDR - #define STBI_NO_HDR - #endif - #ifndef STBI_ONLY_PIC - #define STBI_NO_PIC - #endif - #ifndef STBI_ONLY_PNM - #define STBI_NO_PNM - #endif -#endif - -#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) -#define STBI_NO_ZLIB +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || \ + defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || \ + defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ + defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) +#ifndef STBI_ONLY_JPEG +#define STBI_NO_JPEG +#endif +#ifndef STBI_ONLY_PNG +#define STBI_NO_PNG +#endif +#ifndef STBI_ONLY_BMP +#define STBI_NO_BMP +#endif +#ifndef STBI_ONLY_PSD +#define STBI_NO_PSD +#endif +#ifndef STBI_ONLY_TGA +#define STBI_NO_TGA +#endif +#ifndef STBI_ONLY_GIF +#define STBI_NO_GIF +#endif +#ifndef STBI_ONLY_HDR +#define STBI_NO_HDR +#endif +#ifndef STBI_ONLY_PIC +#define STBI_NO_PIC +#endif +#ifndef STBI_ONLY_PNM +#define STBI_NO_PNM +#endif #endif +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && \ + !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif +#include #include #include // ptrdiff_t on osx #include #include -#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -576,55 +619,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #define STBI_EXTERN extern #endif - #ifndef _MSC_VER - #ifdef __cplusplus - #define stbi_inline inline - #else - #define stbi_inline - #endif +#ifdef __cplusplus +#define stbi_inline inline +#else +#define stbi_inline +#endif #else - #define stbi_inline __forceinline +#define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS - #if defined(__cplusplus) && __cplusplus >= 201103L - #define STBI_THREAD_LOCAL thread_local - #elif defined(__GNUC__) && __GNUC__ < 5 - #define STBI_THREAD_LOCAL __thread - #elif defined(_MSC_VER) - #define STBI_THREAD_LOCAL __declspec(thread) - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) - #define STBI_THREAD_LOCAL _Thread_local - #endif - - #ifndef STBI_THREAD_LOCAL - #if defined(__GNUC__) - #define STBI_THREAD_LOCAL __thread - #endif - #endif +#if defined(__cplusplus) && __cplusplus >= 201103L +#define STBI_THREAD_LOCAL thread_local +#elif defined(__GNUC__) && __GNUC__ < 5 +#define STBI_THREAD_LOCAL __thread +#elif defined(_MSC_VER) +#define STBI_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && \ + !defined(__STDC_NO_THREADS__) +#define STBI_THREAD_LOCAL _Thread_local +#endif + +#ifndef STBI_THREAD_LOCAL +#if defined(__GNUC__) +#define STBI_THREAD_LOCAL __thread +#endif +#endif #endif #ifdef _MSC_VER typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -632,27 +675,30 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL - #define stbi_lrot(x,y) _lrotl(x,y) +#define stbi_lrot(x, y) _lrotl(x, y) #else - #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (32 - (y)))) +#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (32 - (y)))) #endif -#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +#if defined(STBI_MALLOC) && defined(STBI_FREE) && \ + (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) // ok -#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && \ + !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) // ok #else -#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." +#error \ + "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p,newsz) realloc(p,newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p, newsz) realloc(p, newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) +#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) #endif // x86/x64 detection @@ -662,7 +708,8 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI__X86_TARGET #endif -#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && \ + !defined(STBI_NO_SIMD) // gcc doesn't support sse2 intrinsics unless you compile with -msse2, // which in turn means it gets to use SSE2 everywhere. This is unfortunate, // but previous attempts to provide the SSE2 functions with runtime @@ -673,8 +720,10 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI_NO_SIMD #endif -#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) -// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && \ + !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid +// STBI__X64_TARGET // // 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the // Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. @@ -684,44 +733,43 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; // See https://github.com/nothings/stb/issues/81 for more information. // // So default to no SSE2 on 32-bit MinGW. If you've read this far and added -// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +// -mstackrealign to your build settings, feel free to #define +// STBI_MINGW_ENABLE_SSE2. #define STBI_NO_SIMD #endif -#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#if !defined(STBI_NO_SIMD) && \ + (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) #define STBI_SSE2 #include #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) -{ - int info[4]; - __cpuid(info,1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) { + int info[4]; + __cpuid(info, 1); + return info[3]; } #else -static int stbi__cpuid3(void) -{ - int res; - __asm { +static int stbi__cpuid3(void) { + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) { + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -729,12 +777,11 @@ static int stbi__sse2_available(void) #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) { + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -766,188 +813,182 @@ static int stbi__sse2_available(void) // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct -{ - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct { + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void *io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; - static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) -{ - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) { + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) -{ - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, + void *user) { + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void *user, char *data, int size) -{ - return (int) fread(data,1,size,(FILE*) user); +static int stbi__stdio_read(void *user, char *data, int size) { + return (int)fread(data, 1, size, (FILE *)user); } -static void stbi__stdio_skip(void *user, int n) -{ - int ch; - fseek((FILE*) user, n, SEEK_CUR); - ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ - } +static void stbi__stdio_skip(void *user, int n) { + int ch; + fseek((FILE *)user, n, SEEK_CUR); + ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ + } } -static int stbi__stdio_eof(void *user) -{ - return feof((FILE*) user) || ferror((FILE *) user); +static int stbi__stdio_eof(void *user) { + return feof((FILE *)user) || ferror((FILE *)user); } -static stbi_io_callbacks stbi__stdio_callbacks = -{ - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static stbi_io_callbacks stbi__stdio_callbacks = { + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context *s, FILE *f) -{ - stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +static void stbi__start_file(stbi__context *s, FILE *f) { + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } -//static void stop_file(stbi__context *s) { } +// static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context *s) -{ - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 + // bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum -{ - STBI_ORDER_RGB, - STBI_ORDER_BGR -}; +enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; -typedef struct -{ - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct { + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context *s); -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context *s); -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__png_is16(stbi__context *s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context *s); -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context *s); -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s); -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__psd_is16(stbi__context *s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context *s); -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context *s); -static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context *s); -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s); -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); #endif static #ifdef STBI_THREAD_LOCAL -STBI_THREAD_LOCAL + STBI_THREAD_LOCAL #endif -const char *stbi__g_failure_reason; + const char *stbi__g_failure_reason; -STBIDEF const char *stbi_failure_reason(void) -{ - return stbi__g_failure_reason; +STBIDEF const char *stbi_failure_reason(void) { + return stbi__g_failure_reason; } #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char *str) -{ - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) { + stbi__g_failure_reason = str; + return 0; } #endif -static void *stbi__malloc(size_t size) -{ - return STBI_MALLOC(size); +static void *stbi__malloc(size_t size) { + return STBI_MALLOC(size); } // stb_image uses ints pervasively, including for offset calculations. @@ -962,70 +1003,72 @@ static void *stbi__malloc(size_t size) // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) -{ - if (b < 0) return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) { + if (b < 0) + return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) -{ - if (a < 0 || b < 0) return 0; - if (b == 0) return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX/b; +static int stbi__mul2sizes_valid(int a, int b) { + if (a < 0 || b < 0) + return 0; + if (b == 0) + return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; } -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__addsizes_valid(a*b*c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__addsizes_valid(a * b * c, add); } -// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't +// overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__mul2sizes_valid(a * b * c, d) && + stbi__addsizes_valid(a * b * c * d, add); } #endif -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void *stbi__malloc_mad2(int a, int b, int add) -{ - if (!stbi__mad2sizes_valid(a, b, add)) return NULL; - return stbi__malloc(a*b + add); +static void *stbi__malloc_mad2(int a, int b, int add) { + if (!stbi__mad2sizes_valid(a, b, add)) + return NULL; + return stbi__malloc(a * b + add); } #endif -static void *stbi__malloc_mad3(int a, int b, int c, int add) -{ - if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; - return stbi__malloc(a*b*c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) { + if (!stbi__mad3sizes_valid(a, b, c, add)) + return NULL; + return stbi__malloc(a * b * c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) -{ - if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; - return stbi__malloc(a*b*c*d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) { + if (!stbi__mad4sizes_valid(a, b, c, d, add)) + return NULL; + return stbi__malloc(a * b * c * d + add); } #endif @@ -1034,417 +1077,459 @@ static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS - #define stbi__err(x,y) 0 +#define stbi__err(x, y) 0 #elif defined(STBI_FAILURE_USERMSG) - #define stbi__err(x,y) stbi__err(y) +#define stbi__err(x, y) stbi__err(y) #else - #define stbi__err(x,y) stbi__err(x) +#define stbi__err(x, y) stbi__err(x) #endif -#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) -#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpuc(x, y) \ + ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -STBIDEF void stbi_image_free(void *retval_from_stbi_load) -{ - STBI_FREE(retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load) { + STBI_FREE(retval_from_stbi_load); } #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else -static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, + stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ - ? stbi__vertically_flip_on_load_local \ - : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load \ + (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; - - #ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNG - if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_GIF - if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PSD - if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); - #else - STBI_NOTUSED(bpc); - #endif - #ifndef STBI_NO_PIC - if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); - #endif - - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } - #endif - - #ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s,x,y,comp,req_comp, ri); - #endif - - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); -} - -static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi_uc *reduced; - - reduced = (stbi_uc *) stbi__malloc(img_len); - if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - - STBI_FREE(orig); - return reduced; -} - -static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi__uint16 *enlarged; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = + 8; // default is 8 so most paths don't have to be changed + ri->channel_order = + STBI_ORDER_RGB; // all current input & output are this, but this is here + // so we can add BGR order + ri->num_channels = 0; - enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); - if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); +#ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) + return stbi__jpeg_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNG + if (stbi__png_test(s)) + return stbi__png_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) + return stbi__bmp_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_GIF + if (stbi__gif_test(s)) + return stbi__gif_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PSD + if (stbi__psd_test(s)) + return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); +#else + STBI_NOTUSED(bpc); +#endif +#ifndef STBI_NO_PIC + if (stbi__pic_test(s)) + return stbi__pic_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) + return stbi__pnm_load(s, x, y, comp, req_comp, ri); +#endif - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } +#endif - STBI_FREE(orig); - return enlarged; -} +#ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s, x, y, comp, req_comp, ri); +#endif -static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) -{ - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc *bytes = (stbi_uc *)image; - - for (row = 0; row < (h>>1); row++) { - stbi_uc *row0 = bytes + row*bytes_per_row; - stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } + return stbi__errpuc("unknown image type", + "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *)stbi__malloc(img_len); + if (reduced == NULL) + return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = + (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient + // approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); + if (enlarged == NULL) + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + + orig[i]); // replicate to high and low byte, + // maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, + int bytes_per_pixel) { + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h >> 1); row++) { + stbi_uc *row0 = bytes + row * bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1) * bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = + (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) -{ - int slice; - int slice_size = w * h * bytes_per_pixel; +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, + int bytes_per_pixel) { + int slice; + int slice_size = w * h * bytes_per_pixel; - stbi_uc *bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *) result; + return (unsigned char *)result; } -static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to + // keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *) result; + return (stbi__uint16 *)result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) -{ - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, + int req_comp) { + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar( + unsigned int cp, unsigned long flags, const char *str, int cbmb, + wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte( + unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, + char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) -{ - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input) { + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, + (int)bufferlen, NULL, NULL); } #endif -static FILE *stbi__fopen(char const *filename, char const *mode) -{ - FILE *f; +static FILE *stbi__fopen(char const *filename, char const *mode) { + FILE *f; #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, + sizeof(wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) - return 0; + if (0 == + MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) + return 0; #if _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f=0; + if (0 != fopen_s(&f, filename, mode)) + f = 0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; -} - - -STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - unsigned char *result; - if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; -} - -STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__uint16 *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - stbi__uint16 *result; - if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f,x,y,comp,req_comp); - fclose(f); - return result; -} - - -#endif //!STBI_NO_STDIO - -STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); -} - -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + return f; +} + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) + return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) + return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +#endif //! STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_mem(&s,buffer,len); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_mem(&s, buffer, len); - result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); - } + result = + (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices(result, *x, *y, *z, *comp); + } - return result; + return result; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *data; - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data,x,y,comp,req_comp); - return hdr_data; - } - #endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); -} - -STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__loadf_main(&s,x,y,comp,req_comp); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp) { + unsigned char *data; +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data, x, y, comp, req_comp); + return hdr_data; + } +#endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", + "Image not of any known type, or corrupt"); } -STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__loadf_main(&s,x,y,comp,req_comp); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -#ifndef STBI_NO_STDIO -STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - float *result; - FILE *f = stbi__fopen(filename, "rb"); - if (!f) return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_file(&s,f); - return stbi__loadf_main(&s,x,y,comp,req_comp); +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, + int req_comp) { + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) + return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_file(&s, f); + return stbi__loadf_main(&s, x, y, comp, req_comp); } #endif // !STBI_NO_STDIO @@ -1454,221 +1539,222 @@ STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_ // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; +#endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result=0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result = 0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE *f) -{ - #ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s,f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; - #else - STBI_NOTUSED(f); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) { +#ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s, f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; +#else + STBI_NOTUSED(f); + return 0; +#endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; +#endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; +static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { + stbi__l2h_gamma = gamma; +} +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { + stbi__l2h_scale = scale; +} #endif -static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; - -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } +static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { + stbi__h2l_gamma_i = 1 / gamma; +} +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { + stbi__h2l_scale_i = 1 / scale; +} ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum -{ - STBI__SCAN_load=0, - STBI__SCAN_type, - STBI__SCAN_header -}; - -static void stbi__refill_buffer(stbi__context *s) -{ - int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); - s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start+1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } -} - -stbi_inline static stbi_uc stbi__get8(stbi__context *s) -{ - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; -} - -#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; + +static void stbi__refill_buffer(stbi__context *s) { + int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context *s) -{ - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) + return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) + return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context *s, int n) -{ - if (n == 0) return; // already there! - if (n < 0) { +static void stbi__skip(stbi__context *s, int n) { + if (n == 0) + return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); return; - } - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; + } + } + s->img_buffer += n; } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && \ + defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) -{ - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) { + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; - memcpy(buffer, s->img_buffer, blen); + memcpy(buffer, s->img_buffer, blen); - count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); - res = (count == (n-blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } + count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); + res = (count == (n - blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } - if (s->img_buffer+n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + if (s->img_buffer + n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context *s) -{ - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) { + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context *s) -{ - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) { + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context *s) -{ - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) { + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context *s) -{ - stbi__uint32 z = stbi__get16le(s); - return z + (stbi__get16le(s) << 16); +static stbi__uint32 stbi__get32le(stbi__context *s) { + stbi__uint32 z = stbi__get16le(s); + return z + (stbi__get16le(s) << 16); } #endif -#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) \ + ((stbi_uc)((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1682,169 +1768,301 @@ static stbi__uint32 stbi__get32le(stbi__context *s) // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) -{ - return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi_uc stbi__compute_y(int r, int g, int b) { + return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - unsigned char *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - unsigned char *src = data + j * x * img_n ; - unsigned char *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + unsigned char *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + unsigned char *src = data + j * x * img_n; + unsigned char *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 255; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 255; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 255; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = 255; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) -{ - return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { + return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - stbi__uint16 *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - stbi__uint16 *src = data + j * x * img_n ; - stbi__uint16 *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + stbi__uint16 *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + stbi__uint16 *src = data + j * x * img_n; + stbi__uint16 *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 0xffff; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 0xffff; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 0xffff; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = 0xffff; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return (stbi__uint16 *)stbi__errpuc("unsupported", + "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) -{ - int i,k,n; - float *output; - if (!data) return NULL; - output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i=0; i < x*y; ++i) { - output[i*comp + n] = data[i*comp + n]/255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) { + int i, k, n; + float *output; + if (!data) + return NULL; + output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpf("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + output[i * comp + k] = + (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * + stbi__l2h_scale); + } + } + if (n < comp) { + for (i = 0; i < x * y; ++i) { + output[i * comp + n] = data[i * comp + n] / 255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int) (x)) -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) -{ - int i,k,n; - stbi_uc *output; - if (!data) return NULL; - output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - if (k < comp) { - float z = data[i*comp+k] * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int)(x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) { + int i, k, n; + stbi_uc *output; + if (!data) + return NULL; + output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, + stbi__h2l_gamma_i) * + 255 + + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + if (k < comp) { + float z = data[i * comp + k] * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1872,750 +2090,791 @@ static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache - -typedef struct -{ - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct { + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct -{ - stbi__context *s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; - -// sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; - -// definition of jpeg image component - struct - { - int id; - int h,v; - int tq; - int hd,ha; - int dc_pred; - - int x,y,w2,h2; - stbi_uc *data; - void *raw_data, *raw_coeff; - stbi_uc *linebuf; - short *coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; - - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop - - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; - - int scan_n, order[4]; - int restart_interval, todo; - -// kernels - void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); - stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +typedef struct { + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, int count, + int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman *h, int *count) -{ - int i,j,k=0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i=0; i < 16; ++i) - for (j=0; j < count[i]; ++j) - h->size[k++] = (stbi_uc) (i+1); - h->size[k] = 0; - - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for(j=1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16) (code++); - if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16-j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; - - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i=0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS-s); - int m = 1 << (FAST_BITS-s); - for (j=0; j < m; ++j) { - h->fast[c+j] = (stbi_uc) i; - } +static int stbi__build_huffman(stbi__huffman *h, int *count) { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) + for (j = 0; j < count[i]; ++j) + h->size[k++] = (stbi_uc)(i + 1); + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16)(code++); + if (code - 1 >= (1u << j)) + return stbi__err("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (stbi_uc)i; } - } - return 1; + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) -{ - int i; - for (i=0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; - - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); - } +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) + k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); } - } -} - -static void stbi__grow_buffer_unsafe(stbi__jpeg *j) -{ - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char) c; - j->nomore = 1; - return; - } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) { + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) + c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; +static const stbi__uint32 stbi__bmask[17] = { + 0, 1, 3, 7, 15, 31, 63, 127, 255, + 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) -{ - unsigned int temp; - int c,k; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } - - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k=FAST_BITS+1 ; ; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } - - if (k > j->code_bits) +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) return -1; - - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1;; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & + stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) stbi__grow_buffer_unsafe(j); - - sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB - k = stbi_lrot(j->code_buffer, n); - if (n < 0 || n >= (int) (sizeof(stbi__bmask)/sizeof(*stbi__bmask))) return 0; - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & ~sgn); +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) { + unsigned int k; + int sgn; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + + sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB + k = stbi_lrot(j->code_buffer, n); + if (n < 0 || n >= (int)(sizeof(stbi__bmask) / sizeof(*stbi__bmask))) + return 0; + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & ~sgn); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) -{ - unsigned int k; - if (j->code_bits < n) stbi__grow_buffer_unsafe(j); - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; -} - -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) -{ - unsigned int k; - if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) { + unsigned int k; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) { + unsigned int k; + if (j->code_bits < 1) + stbi__grow_buffer_unsafe(j); + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64+15] = -{ - 0, 1, 8, 16, 9, 2, 3, 10, - 17, 24, 32, 25, 18, 11, 4, 5, - 12, 19, 26, 33, 40, 48, 41, 34, - 27, 20, 13, 6, 7, 14, 21, 28, - 35, 42, 49, 56, 57, 50, 43, 36, - 29, 22, 15, 23, 30, 37, 44, 51, - 58, 59, 52, 45, 38, 31, 39, 46, - 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, - 63, 63, 63, 63, 63, 63, 63 -}; +static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { + 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, + 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, + 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) -{ - int diff,dc,k; - int t; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - - // 0 all the ac values now so we can do it 32-bits at a time - memset(data,0,64*sizeof(data[0])); - - diff = t ? stbi__extend_receive(j, t) : 0; - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc * dequant[0]); - - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) * dequant[zig]); +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, stbi__huffman *hac, + stbi__int16 *fac, int b, + stbi__uint16 *dequant) { + int diff, dc, k; + int t; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data, 0, 64 * sizeof(data[0])); + + diff = t ? stbi__extend_receive(j, t) : 0; + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) + break; // end block + k += 16; } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (rs != 0xf0) break; // end block - k += 16; - } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); - } + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); } - } while (k < 64); - return 1; -} - -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) -{ - int diff,dc; - int t; - if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data,0,64*sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t == -1) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; - - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc << j->succ_low); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short) (1 << j->succ_low); - } - return 1; + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, int b) { + int diff, dc; + int t; + if (j->spec_end != 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t == -1) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; + + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc << j->succ_low); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; } // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) -{ - int k; - if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->succ_high == 0) { - int shift = j->succ_low; +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], + stbi__huffman *hac, + stbi__int16 *fac) { + int k; + if (j->spec_start == 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } - if (j->eob_run) { - --j->eob_run; - return 1; + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) << shift); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) << shift); + } } - + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { k = j->spec_start; do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) << shift); - } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) << shift); - } - } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients - - short bit = (short) (1 << j->succ_low); - - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short *p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } - } else { - k = j->spec_start; - do { - int r,s; - int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } + int r, s; + int rs = stbi__jpeg_huff_decode( + j, hac); // @OPTIMIZE see if we can use the fast path here, + // advance-by-r is so slow, eh + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) + return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } - // advance by r - while (k <= j->spec_end) { - short *p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short) s; - break; - } - --r; - } + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short)s; + break; } - } while (k <= j->spec_end); - } - } - return 1; + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) -{ - // trick to use a single test to catch both cases - if ((unsigned int) x > 255) { - if (x < 0) return 0; - if (x > 255) return 255; - } - return (stbi_uc) x; +stbi_inline static stbi_uc stbi__clamp(int x) { + // trick to use a single test to catch both cases + if ((unsigned int)x > 255) { + if (x < 0) + return 0; + if (x > 255) + return 255; + } + return (stbi_uc)x; } -#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) -#define stbi__fsh(x) ((x) * 4096) +#define stbi__f2f(x) ((int)(((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ - int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2+p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3*stbi__f2f(-1.847759065f); \ - t3 = p1 + p2*stbi__f2f( 0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2+p3); \ - t1 = stbi__fsh(p2-p3); \ - x0 = t0+t3; \ - x3 = t0-t3; \ - x1 = t1+t2; \ - x2 = t1-t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0+t2; \ - p4 = t1+t3; \ - p1 = t0+t3; \ - p2 = t1+t2; \ - p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ - t0 = t0*stbi__f2f( 0.298631336f); \ - t1 = t1*stbi__f2f( 2.053119869f); \ - t2 = t2*stbi__f2f( 3.072711026f); \ - t3 = t3*stbi__f2f( 1.501321110f); \ - p1 = p5 + p1*stbi__f2f(-0.899976223f); \ - p2 = p5 + p2*stbi__f2f(-2.562915447f); \ - p3 = p3*stbi__f2f(-1.961570560f); \ - p4 = p4*stbi__f2f(-0.390180644f); \ - t3 += p1+p4; \ - t2 += p2+p3; \ - t1 += p2+p4; \ - t0 += p1+p3; - -static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) -{ - int i,val[64],*v=val; - stbi_uc *o; - short *d = data; - - // columns - for (i=0; i < 8; ++i,++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 - && d[40]==0 && d[48]==0 && d[56]==0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0]*4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; x1 += 512; x2 += 512; x3 += 512; - v[ 0] = (x0+t3) >> 10; - v[56] = (x0-t3) >> 10; - v[ 8] = (x1+t2) >> 10; - v[48] = (x1-t2) >> 10; - v[16] = (x2+t1) >> 10; - v[40] = (x2-t1) >> 10; - v[24] = (x3+t0) >> 10; - v[32] = (x3-t0) >> 10; - } - } - - for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128<<17); - x1 += 65536 + (128<<17); - x2 += 65536 + (128<<17); - x3 += 65536 + (128<<17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0+t3) >> 17); - o[7] = stbi__clamp((x0-t3) >> 17); - o[1] = stbi__clamp((x1+t2) >> 17); - o[6] = stbi__clamp((x1-t2) >> 17); - o[2] = stbi__clamp((x2+t1) >> 17); - o[5] = stbi__clamp((x2-t1) >> 17); - o[3] = stbi__clamp((x3+t0) >> 17); - o[4] = stbi__clamp((x3-t0) >> 17); - } +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ + t0 += p1 + p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) { + int i, val[64], *v = val; + stbi_uc *o; + short *d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && + d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; + x1 += 512; + x2 += 512; + x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0 + t3) >> 17); + o[7] = stbi__clamp((x0 - t3) >> 17); + o[1] = stbi__clamp((x1 + t2) >> 17); + o[6] = stbi__clamp((x1 - t2) >> 17); + o[2] = stbi__clamp((x2 + t1) >> 17); + o[5] = stbi__clamp((x2 - t1) >> 17); + o[3] = stbi__clamp((x3 + t0) >> 17); + o[4] = stbi__clamp((x3 - t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; - - // dot product constant: even elems=x, odd elems=y - #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) - - // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) - // out(1) = c1[even]*x + c1[odd]*y - #define dct_rot(out0,out1, x,y,c0,c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) - - // out = in << 12 (in 16-bit, out 32-bit) - #define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - - // wide add - #define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - - // wide sub - #define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) - - // butterfly a/b, add bias, then shift by "s" and pack - #define dct_bfly32o(out0, out1, a,b,bias,s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } - - // 8-bit interleave step (for transposes) - #define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) - - // 16-bit interleave step (for transposes) - #define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) - - #define dct_pass(bias,shift) \ - { \ - /* even part */ \ - dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ - dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0,row7, x0,x7,bias,shift); \ - dct_bfly32o(row1,row6, x1,x6,bias,shift); \ - dct_bfly32o(row2,row5, x2,x5,bias,shift); \ - dct_bfly32o(row3,row4, x3,x4,bias,shift); \ - } +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + +// dot product constant: even elems=x, odd elems=y +#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + +// out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = \ + _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = \ + _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); - - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); - - // load - row0 = _mm_load_si128((const __m128i *) (data + 0*8)); - row1 = _mm_load_si128((const __m128i *) (data + 1*8)); - row2 = _mm_load_si128((const __m128i *) (data + 2*8)); - row3 = _mm_load_si128((const __m128i *) (data + 3*8)); - row4 = _mm_load_si128((const __m128i *) (data + 4*8)); - row5 = _mm_load_si128((const __m128i *) (data + 5*8)); - row6 = _mm_load_si128((const __m128i *) (data + 6*8)); - row7 = _mm_load_si128((const __m128i *) (data + 7*8)); - - // column pass - dct_pass(bias_0, 10); - - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); - - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); - - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } - - // row pass - dct_pass(bias_1, 17); - - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); - - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... - - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... - - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... +// wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - // store - _mm_storel_epi64((__m128i *) out, p0); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p2); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p1); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p3); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); - } +// wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + +// butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = \ + _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = \ + _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + +// 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + +// 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), + stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), + stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), + stbi__f2f(1.175875602f)); + __m128i rot1_1 = + dct_const(stbi__f2f(1.175875602f), + stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), + stbi__f2f(-1.961570560f)); + __m128i rot2_1 = + dct_const(stbi__f2f(-1.961570560f), + stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), + stbi__f2f(-0.390180644f)); + __m128i rot3_1 = + dct_const(stbi__f2f(-0.390180644f), + stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); + row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); + row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); + row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); + row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); + row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); + row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); + row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *)out, p0); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p2); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p1); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p3); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2634,198 +2893,236 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); - -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) - -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) - -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ - dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ - dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ - dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ - } - - // load - row0 = vld1q_s16(data + 0*8); - row1 = vld1q_s16(data + 1*8); - row2 = vld1q_s16(data + 2*8); - row3 = vld1q_s16(data + 3*8); - row4 = vld1q_s16(data + 4*8); - row5 = vld1q_s16(data + 5*8); - row6 = vld1q_s16(data + 6*8); - row7 = vld1q_s16(data + 7*8); - - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - - // column pass - dct_pass(vrshrn_n_s32, 10); - - // 16bit 8x8 transpose - { +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ + } + + // load + row0 = vld1q_s16(data + 0 * 8); + row1 = vld1q_s16(data + 1 * 8); + row2 = vld1q_s16(data + 2 * 8); + row3 = vld1q_s16(data + 3 * 8); + row4 = vld1q_s16(data + 4 * 8); + row5 = vld1q_s16(data + 5 * 8); + row6 = vld1q_s16(data + 6 * 8); + row7 = vld1q_s16(data + 7 * 8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } -#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } - - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); - - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); - - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = \ + vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ + } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } - - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); - - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } -#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } - - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! - - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); - - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); - - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); - - // store - vst1_u8(out, p0); out += out_stride; - vst1_u8(out, p1); out += out_stride; - vst1_u8(out, p2); out += out_stride; - vst1_u8(out, p3); out += out_stride; - vst1_u8(out, p4); out += out_stride; - vst1_u8(out, p5); out += out_stride; - vst1_u8(out, p6); out += out_stride; - vst1_u8(out, p7); + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ + } +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ + } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); + out += out_stride; + vst1_u8(out, p1); + out += out_stride; + vst1_u8(out, p2); + out += out_stride; + vst1_u8(out, p3); + out += out_stride; + vst1_u8(out, p4); + out += out_stride; + vst1_u8(out, p5); + out += out_stride; + vst1_u8(out, p6); + out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -2838,1132 +3135,1274 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg *j) -{ - stbi_uc x; - if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } - x = stbi__get8(j->s); - if (x != 0xff) return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) { + stbi_uc x; + if (j->marker != STBI__MARKER_none) { + x = j->marker; + j->marker = STBI__MARKER_none; + return x; + } + x = stbi__get8(j->s); + if (x != 0xff) + return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg *j) -{ - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels -} - -static int stbi__parse_entropy_coded_data(stbi__jpeg *z) -{ - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i,j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - STBI_SIMD_ALIGN(short, data[64]); - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x)*8; - int y2 = (j*z->img_comp[n].v + y)*8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; +static void stbi__jpeg_reset(stbi__jpeg *j) { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = + j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) { + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } else { - if (z->scan_n == 1) { - int i,j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x); - int y2 = (j*z->img_comp[n].v + y); - short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } + return 1; + } else { // interleaved + int i, j, k, x, y; + STBI_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, + z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + + z->img_comp[n].w2 * y2 + x2, + z->img_comp[n].w2, data); + } } - } - return 1; + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) -{ - int i; - for (i=0; i < 64; ++i) - data[i] *= dequant[i]; -} - -static void stbi__jpeg_finish(stbi__jpeg *z) -{ - if (z->progressive) { - // dequantize and idct the data - int i,j,n; - for (n=0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - } - } + return 1; + } + } else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], + z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static int stbi__process_marker(stbi__jpeg *z, int m) -{ - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker","Corrupt JPEG"); - - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; - - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s)-2; - while (L > 0) { - int q = stbi__get8(z->s); - int p = q >> 4, sixteen = (p != 0); - int t = q & 15,i; - if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); - if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - - for (i=0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); - L -= (sixteen ? 129 : 65); - } - return L==0; - - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s)-2; - while (L > 0) { - stbi_uc *v; - int sizes[16],i,n=0; - int q = stbi__get8(z->s); - int tc = q >> 4; - int th = q & 15; - if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); - for (i=0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; - } - L -= 17; - if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; - v = z->huff_dc[th].values; - } else { - if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; - v = z->huff_ac[th].values; + return 1; + } else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } } - for (i=0; i < n; ++i) - v[i] = stbi__get8(z->s); - if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); - L -= n; - } - return L==0; - } - - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len","Corrupt JPEG"); - else - return stbi__err("bad APP len","Corrupt JPEG"); + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - L -= 2; - - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J','F','I','F','\0'}; - int ok = 1; - int i; - for (i=0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; - int ok = 1; - int i; - for (i=0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + } } + } + } +} - stbi__skip(z->s, L); - return 1; - } +static int stbi__process_marker(stbi__jpeg *z, int m) { + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker", "Corrupt JPEG"); - return stbi__err("unknown marker","Corrupt JPEG"); -} + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) + return stbi__err("bad DRI len", "Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; -// after we see SOS -static int stbi__process_scan_header(stbi__jpeg *z) -{ - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); - if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); - for (i=0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s) - 2; + while (L > 0) { int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) return 0; // no match - z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); - z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); - z->order[i] = which; - } - - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) + return stbi__err("bad DQT type", "Corrupt JPEG"); + if (t > 3) + return stbi__err("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = + (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + stbi_uc *v; + int sizes[16], i, n = 0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) + return stbi__err("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc + th, sizes)) + return 0; + v = z->huff_dc[th].values; } else { - if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); - z->spec_end = 63; + if (!stbi__build_huffman(z->huff_ac + th, sizes)) + return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len", "Corrupt JPEG"); + else + return stbi__err("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; } - } + } + + stbi__skip(z->s, L); + return 1; + } - return 1; + return stbi__err("unknown marker", "Corrupt JPEG"); } -static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) -{ - int i; - for (i=0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) { + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) + return stbi__err("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) + return stbi__err("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) + return 0; // no match + z->img_comp[which].hd = q >> 4; + if (z->img_comp[which].hd > 3) + return stbi__err("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; + if (z->img_comp[which].ha > 3) + return stbi__err("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || + z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; } -static int stbi__process_frame_header(stbi__jpeg *z, int scan) -{ - stbi__context *s = z->s; - int Lf,p,i,q, h_max=1,v_max=1,c; - Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG - p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); - s->img_n = c; - for (i=0; i < c; ++i) { +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } - - if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); - - z->rgb = 0; - for (i=0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = { 'R', 'G', 'B' }; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); - z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); - } - - if (scan != STBI__SCAN_load) return 1; - - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); - - for (i=0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; - } - - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; - - for (i=0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); - } - } + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) { + stbi__context *s = z->s; + int Lf, p, i, q, h_max = 1, v_max = 1, c; + Lf = stbi__get16be(s); + if (Lf < 11) + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + p = stbi__get8(s); + if (p != 8) + return stbi__err("only 8-bit", + "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); + if (s->img_y == 0) + return stbi__err( + "no header height", + "JPEG format not supported: delayed height"); // Legal, but we don't + // handle it--but neither + // does IJG + s->img_x = stbi__get16be(s); + if (s->img_x == 0) + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) + return stbi__err("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) + return stbi__err("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = {'R', 'G', 'B'}; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); + if (!z->img_comp[i].h || z->img_comp[i].h > 4) + return stbi__err("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; + if (!z->img_comp[i].v || z->img_comp[i].v > 4) + return stbi__err("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); + if (z->img_comp[i].tq > 3) + return stbi__err("bad TQ", "Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) + return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) + return stbi__err("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) + h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) + v_max = z->img_comp[i].v; + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked + // earlier) so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = + stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i + 1, + stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = + (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3( + z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components( + z, i + 1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = + (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } - return 1; + return 1; } // use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) - -#define stbi__SOF_progressive(x) ((x) == 0xc2) - -static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) -{ - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); - if (scan == STBI__SCAN_type) return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z,m)) return 0; +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) + return stbi__err("no SOI", "Corrupt JPEG"); + if (scan == STBI__SCAN_type) + return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z, m)) + return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) + return stbi__err("no SOF", "Corrupt JPEG"); m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) return 0; - return 1; + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) + return 0; + return 1; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg *j) -{ - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) return 0; - if (!stbi__parse_entropy_coded_data(j)) return 0; - if (j->marker == STBI__MARKER_none ) { - // handle 0s at the end of image data from IP Kamera 9060 - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - if (x == 255) { - j->marker = stbi__get8(j->s); - break; - } - } - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); - } else { - if (!stbi__process_marker(j, m)) return 0; +static int stbi__decode_jpeg_image(stbi__jpeg *j) { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) + return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) + return 0; + if (!stbi__parse_entropy_coded_data(j)) + return 0; + if (j->marker == STBI__MARKER_none) { + // handle 0s at the end of image data from IP Kamera 9060 + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + if (x == 255) { + j->marker = stbi__get8(j->s); + break; + } + } + // if we reach eof without hitting a marker, stbi__get_marker() below + // will fail and we'll eventually return 0 } - m = stbi__get_marker(j); - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) + return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) + return stbi__err("bad DNL height", "Corrupt JPEG"); + } else { + if (!stbi__process_marker(j, m)) + return 0; + } + m = stbi__get_marker(j); + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, - int w, int hs); - -#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) - -static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; -} - -static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i=0; i < w; ++i) - out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); - return out; + int w, int hs); + +#define stbi__div4(x) ((stbi_uc)((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, + int w, int hs) { + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc *stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc *stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = stbi__div4(n + input[i - 1]); + out[i * 2 + 1] = stbi__div4(n + input[i + 1]); + } + out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc)((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = stbi__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; } -static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples horizontally for every one in input - int i; - stbi_uc *input = in_near; - - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } - - out[0] = input[0]; - out[1] = stbi__div4(input[0]*3 + input[1] + 2); - for (i=1; i < w-1; ++i) { - int n = 3*input[i]+2; - out[i*2+0] = stbi__div4(n+input[i-1]); - out[i*2+1] = stbi__div4(n+input[i+1]); - } - out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); - out[i*2+1] = input[w-1]; - - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); - - return out; -} +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = + _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *)(out + i * 2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = + vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif -#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } -static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i,t0,t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - t1 = 3*in_near[0] + in_far[0]; - out[0] = stbi__div4(t1+2); - for (i=1; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } +#endif -#if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i=0,t0,t1; - - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } - - t1 = 3*in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w-1) & ~7); i += 8) { -#if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); - - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); - - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *) (out + i*2), outv); -#elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); - - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i*2, o); -#endif - - // "previous" value for next iter - t1 = 3*in_near[i+7] + in_far[i+7]; - } - - t0 = t1; - t1 = 3*in_near[i] + in_far[i]; - out[i*2] = stbi__div16(3*t1 + t0 + 8); - - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); - - STBI_NOTUSED(hs); - - return out; -} -#endif - -static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // resample with nearest-neighbor - int i,j; - STBI_NOTUSED(in_far); - for (i=0; i < w; ++i) - for (j=0; j < hs; ++j) - out[i*hs+j] = in_near[i]; - return out; +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // resample with nearest-neighbor + int i, j; + STBI_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) -{ - int i; - for (i=0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int)((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, + int count, int step) { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) -{ - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, + stbi_uc const *pcb, stbi_uc const *pcr, + int count, int step) { + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); - __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); - __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); - __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); - __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); - __m128i xw = _mm_set1_epi16(255); // alpha channel - - for (; i+7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); - - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); - - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); - - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); - - // store - _mm_storeu_si128((__m128i *) (out + 0), o0); - _mm_storeu_si128((__m128i *) (out + 16), o1); - out += 32; - } - } + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *)(out + 0), o0); + _mm_storeu_si128((__m128i *)(out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); - int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); - int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); - int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); - - for (; i+7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); - - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); - - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); - - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8*4; - } - } -#endif - - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + cr * -stbi__float2fixed(0.71414f) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg *j) -{ - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) { + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg *j) -{ - stbi__free_jpeg_components(j, j->s->img_n, 0); +static void stbi__cleanup_jpeg(stbi__jpeg *j) { + stbi__free_jpeg_components(j, j->s->img_n, 0); } -typedef struct -{ - resample_row_func resample; - stbi_uc *line0,*line1; - int hs,vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +typedef struct { + resample_row_func resample; + stbi_uc *line0, *line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) -{ - unsigned int t = x*y + 128; - return (stbi_uc) ((t + (t >>8)) >> 8); -} - -static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) -{ - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe - - // validate req_comp - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } - - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; - - // resample and color-convert - { - int k; - unsigned int i,j; - stbi_uc *output; - stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; - - stbi__resample res_comp[4]; - - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs-1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; - else r->resample = stbi__resample_row_generic; +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { + unsigned int t = x * y + 128; + return (stbi_uc)((t + (t >> 8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, + int *comp, int req_comp) { + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && + (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // resample and color-convert + { + int k; + unsigned int i, j; + stbi_uc *output; + stbi_uc *coutput[4] = {NULL, NULL, NULL, NULL}; + + stbi__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); } - // can't error after this so, this is safe - output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - // now go ahead and resample - for (j=0; j < z->s->img_y; ++j) { - stbi_uc *out = output + n * z->s->img_x * j; - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, - y_bot ? r->line1 : r->line0, - y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) + r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) + r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) + r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) + r->resample = z->resample_row_hv_2_kernel; + else + r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = + r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; } - } - if (n >= 3) { - stbi_uc *y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i=0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } - } else { - if (is_rgb) { - if (n == 1) - for (i=0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i=0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc *y = coutput[0]; - if (n == 1) - for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; - else - for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; } - } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { + *out++ = y[i]; + *out++ = 255; + } + } } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } -} - -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - unsigned char* result; - stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x,y,comp,req_comp); - STBI_FREE(j); - return result; -} - -static int stbi__jpeg_test(stbi__context *s) -{ - int r; - stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; -} - -static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) -{ - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind( j->s ); - return 0; - } - if (x) *x = j->s->img_x; - if (y) *y = j->s->img_y; - if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; -} - -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) -{ - int result; - stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) + *comp = + z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + unsigned char *result; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) { + int r; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) { + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind(j->s); + return 0; + } + if (x) + *x = j->s->img_x; + if (y) + *y = j->s->img_y; + if (comp) + *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) { + int result; + stbi__jpeg *j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -3977,83 +4416,81 @@ static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct -{ - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[288]; - stbi__uint16 value[288]; +typedef struct { + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[288]; + stbi__uint16 value[288]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) -{ - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); +stbi_inline static int stbi__bitreverse16(int n) { + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); return n; } -stbi_inline static int stbi__bit_reverse(int v, int bits) -{ - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16-bits); -} - -static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) -{ - int i,k=0; - int code, next_code[16], sizes[17]; - - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i=0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i=1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i=1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16) code; - z->firstsymbol[i] = (stbi__uint16) k; - code = (code + sizes[i]); - if (sizes[i]) - if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); - z->maxcode[i] = code << (16-i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i=0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); - z->size [c] = (stbi_uc ) s; - z->value[c] = (stbi__uint16) i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s],s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } - } - ++next_code[s]; +stbi_inline static int stbi__bit_reverse(int v, int bits) { + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16 - bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, + int num) { + int i, k = 0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i = 0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i = 1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i = 1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16)code; + z->firstsymbol[i] = (stbi__uint16)k; + code = (code + sizes[i]); + if (sizes[i]) + if (code - 1 >= (1 << i)) + return stbi__err("bad codelengths", "Corrupt PNG"); + z->maxcode[i] = code << (16 - i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i = 0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); + z->size[c] = (stbi_uc)s; + z->value[c] = (stbi__uint16)i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s], s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } } - } - return 1; + ++next_code[s]; + } + } + return 1; } // zlib-from-memory implementation for PNG reading @@ -4062,277 +4499,313 @@ static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int // we require PNG read all the IDATs and combine them into a single // memory buffer -typedef struct -{ - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; +typedef struct { + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; - char *zout; - char *zout_start; - char *zout_end; - int z_expandable; + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; - stbi__zhuffman z_length, z_distance; + stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf *z) -{ - return (z->zbuffer >= z->zbuffer_end); -} - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) -{ - return stbi__zeof(z) ? 0 : *z->zbuffer++; -} - -static void stbi__fill_bits(stbi__zbuf *z) -{ - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); +stbi_inline static int stbi__zeof(stbi__zbuf *z) { + return (z->zbuffer >= z->zbuffer_end); } -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) -{ - unsigned int k; - if (z->num_bits < n) stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) { + return stbi__zeof(z) ? 0 : *z->zbuffer++; } -static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s,k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s=STBI__ZFAST_BITS+1; ; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) return -1; // invalid code! - // code size is s, so: - b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= sizeof (z->size)) return -1; // some data was corrupt somewhere! - if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes -{ - char *q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); - cur = (unsigned int) (z->zout - z->zout_start); - limit = old_limit = (unsigned) (z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; +static void stbi__fill_bits(stbi__zbuf *z) { + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) { + unsigned int k; + if (z->num_bits < n) + stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s, k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s = STBI__ZFAST_BITS + 1;; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) + return -1; // invalid code! + // code size is s, so: + b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= sizeof(z->size)) + return -1; // some data was corrupt somewhere! + if (z->size[b] != s) + return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. */ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, + int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) + return stbi__err("output buffer limit", "Corrupt PNG"); + cur = (unsigned int)(z->zout - z->zout_start); + limit = old_limit = (unsigned)(z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned)n) + return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if (limit > UINT_MAX / 2) + return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) + return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; } static const int stbi__zlength_base[31] = { - 3,4,5,6,7,8,9,10,11,13, - 15,17,19,23,27,31,35,43,51,59, - 67,83,99,115,131,163,195,227,258,0,0 }; - -static const int stbi__zlength_extra[31]= -{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; - -static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, -257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; - -static const int stbi__zdist_extra[32] = -{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; - -static int stbi__parse_huffman_block(stbi__zbuf *a) -{ - char *zout = a->zout; - for(;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) return 0; - zout = a->zout; - } - *zout++ = (char) z; + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + +static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, + 4, 4, 5, 5, 5, 5, 0, 0, 0}; + +static const int stbi__zdist_base[32] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; + +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) { + char *zout = a->zout; + for (;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) + return stbi__err("bad huffman code", + "Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) + return 0; + zout = a->zout; + } + *zout++ = (char)z; + } else { + stbi_uc *p; + int len, dist; + if (z == 256) { + a->zout = zout; + return 1; + } + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) + len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0) + return stbi__err("bad huffman code", "Corrupt PNG"); + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) + dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) + return stbi__err("bad dist", "Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) + return 0; + zout = a->zout; + } + p = (stbi_uc *)(zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { + do + *zout++ = v; + while (--len); + } } else { - stbi_uc *p; - int len,dist; - if (z == 256) { - a->zout = zout; - return 1; - } - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) return 0; - zout = a->zout; - } - p = (stbi_uc *) (zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { do *zout++ = v; while (--len); } - } else { - if (len) { do *zout++ = *p++; while (--len); } - } + if (len) { + do + *zout++ = *p++; + while (--len); + } } - } -} - -static int stbi__compute_huffman_codes(stbi__zbuf *a) -{ - static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286+32+137];//padding for maximum single op - stbi_uc codelength_sizes[19]; - int i,n; - - int hlit = stbi__zreceive(a,5) + 257; - int hdist = stbi__zreceive(a,5) + 1; - int hclen = stbi__zreceive(a,4) + 4; - int ntot = hlit + hdist; - - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i=0; i < hclen; ++i) { - int s = stbi__zreceive(a,3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; - - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc) c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a,2)+3; - if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n-1]; - } else if (c == 17) { - c = stbi__zreceive(a,3)+3; - } else if (c == 18) { - c = stbi__zreceive(a,7)+11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); - } - if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes+n, fill, c); - n += c; + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) { + static const stbi_uc length_dezigzag[19] = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc codelength_sizes[19]; + int i, n; + + int hlit = stbi__zreceive(a, 5) + 257; + int hdist = stbi__zreceive(a, 5) + 1; + int hclen = stbi__zreceive(a, 4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i = 0; i < hclen; ++i) { + int s = stbi__zreceive(a, 3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) + return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc)c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a, 2) + 3; + if (n == 0) + return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n - 1]; + } else if (c == 17) { + c = stbi__zreceive(a, 3) + 3; + } else if (c == 18) { + c = stbi__zreceive(a, 7) + 11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); } - } - if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; - return 1; -} - -static int stbi__parse_uncompressed_block(stbi__zbuf *a) -{ - stbi_uc header[4]; - int len,nlen,k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; -} - -static int stbi__parse_zlib_header(stbi__zbuf *a) -{ - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; -} - -static const stbi_uc stbi__zdefault_length[288] = -{ - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 -}; -static const stbi_uc stbi__zdefault_distance[32] = -{ - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 -}; + if (ntot - n < c) + return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes + n, fill, c); + n += c; + } + } + if (n != ntot) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) + return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) { + stbi_uc header[4]; + int len, nlen, k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = + (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) + return stbi__err("zlib corrupt", "Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) + return stbi__err("zlib corrupt", "Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) + return stbi__err("read past buffer", "Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) + return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) { + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if ((cmf * 256 + flg) % 31 != 0) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if (flg & 32) + return stbi__err("no preset dict", + "Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) + return stbi__err("bad compression", + "Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[288] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; +static const stbi_uc stbi__zdefault_distance[32] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; /* Init algorithm: { @@ -4346,117 +4819,131 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) -{ - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a,1); - type = stbi__zreceive(a,2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) return 0; - } else if (type == 3) { - return 0; +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) { + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) + return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a, 1); + type = stbi__zreceive(a, 2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) + return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, 288)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) + return 0; } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , 288)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; - } else { - if (!stbi__compute_huffman_codes(a)) return 0; - } - if (!stbi__parse_huffman_block(a)) return 0; + if (!stbi__compute_huffman_codes(a)) + return 0; } - } while (!final); - return 1; -} - -static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) -{ - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; - - return stbi__parse_zlib(a, parse_header); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) -{ - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int) (a.zout - a.zout_start); - else - return -1; -} - -STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(16384); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer+len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int) (a.zout - a.zout_start); - else - return -1; + if (!stbi__parse_huffman_block(a)) + return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, + int parse_header) { + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, + int *outlen) { + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + char const *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int)(a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, + int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(16384); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int)(a.zout - a.zout_start); + else + return -1; } #endif @@ -4471,1083 +4958,1312 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct -{ - stbi__uint32 length; - stbi__uint32 type; +typedef struct { + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) -{ - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) { + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context *s) -{ - static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; - int i; - for (i=0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) { + static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + int i; + for (i = 0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) + return stbi__err("bad png sig", "Not a PNG"); + return 1; } -typedef struct -{ - stbi__context *s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct { + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; - enum { - STBI__F_none=0, - STBI__F_sub=1, - STBI__F_up=2, - STBI__F_avg=3, - STBI__F_paeth=4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none = 0, + STBI__F_sub = 1, + STBI__F_up = 2, + STBI__F_avg = 3, + STBI__F_paeth = 4, + // synthetic filters used for first scanline to avoid needing a dummy row of + // 0s + STBI__F_avg_first, + STBI__F_paeth_first }; -static stbi_uc first_row_filter[5] = -{ - STBI__F_none, - STBI__F_sub, - STBI__F_none, - STBI__F_avg_first, - STBI__F_paeth_first -}; +static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, + STBI__F_avg_first, STBI__F_paeth_first}; -static int stbi__paeth(int a, int b, int c) -{ - int p = a + b - c; - int pa = abs(p-a); - int pb = abs(p-b); - int pc = abs(p-c); - if (pa <= pb && pa <= pc) return a; - if (pb <= pc) return b; - return c; +static int stbi__paeth(int a, int b, int c) { + int p = a + b - c; + int pa = abs(p - a); + int pb = abs(p - b); + int pc = abs(p - c); + if (pa <= pb && pa <= pc) + return a; + if (pb <= pc) + return b; + return c; } -static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; +static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, + 0, 0, 0, 0x01}; // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) -{ - int bytes = (depth == 16? 2 : 1); - stbi__context *s = a->s; - stbi__uint32 i,j,stride = x*out_n*bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later - - int output_bytes = out_n*bytes; - int filter_bytes = img_n*bytes; - int width = x; - - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); - a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) return stbi__err("outofmem", "Out of memory"); - - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; - - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); - - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *prior; - int filter = *raw++; - - if (filter > 4) - return stbi__err("invalid filter","Corrupt PNG"); - - if (depth < 8) { - if (img_width_bytes > x) return stbi__err("invalid width","Corrupt PNG"); - cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - - // if first row, use special filter that doesn't sample previous row - if (j == 0) filter = first_row_filter[filter]; - - // handle first byte explicitly - for (k=0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none : cur[k] = raw[k]; break; - case STBI__F_sub : cur[k] = raw[k]; break; - case STBI__F_up : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break; - case STBI__F_avg : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break; - case STBI__F_paeth : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break; - case STBI__F_avg_first : cur[k] = raw[k]; break; - case STBI__F_paeth_first: cur[k] = raw[k]; break; - } +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, + stbi__uint32 raw_len, int out_n, + stbi__uint32 x, stbi__uint32 y, int depth, + int color) { + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i, j, stride = x * out_n * bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n * bytes; + int filter_bytes = img_n * bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); + a->out = (stbi_uc *)stbi__malloc_mad3( + x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) + return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) + return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on + // non-interlaced PNGs, but issue #276 reported a PNG in the wild that had + // extra data at the end (all zeros), so just check for raw_len < img_len + // always. + if (raw_len < img_len) + return stbi__err("not enough pixels", "Corrupt PNG"); + + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter", "Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) + return stbi__err("invalid width", "Corrupt PNG"); + cur += + x * out_n - img_width_bytes; // store output to the rightmost img_len + // bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = + cur - + stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) + filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k = 0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none: + cur[k] = raw[k]; + break; + case STBI__F_sub: + cur[k] = raw[k]; + break; + case STBI__F_up: + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); + break; + case STBI__F_paeth: + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); + break; + case STBI__F_avg_first: + cur[k] = raw[k]; + break; + case STBI__F_paeth_first: + cur[k] = raw[k]; + break; } + } - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes+1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1)*filter_bytes; - #define STBI__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: memcpy(cur, raw, nk); break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break; - } - #undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n+1 == out_n); - #define STBI__CASE(f) \ - case f: \ - for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break; - } - #undef STBI__CASE - - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride*j; // start at the beginning of the row again - for (i=0; i < x; ++i,cur+=output_bytes) { - cur[filter_bytes+1] = 255; - } - } + // this is a little gross, so that we don't switch per-pixel or + // per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1) * filter_bytes; +#define STBI__CASE(f) \ + case f: \ + for (k = 0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. + case STBI__F_none: + memcpy(cur, raw, nk); + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - filter_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - filter_bytes], prior[k], + prior[k - filter_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); + } + break; } - } - - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *in = a->out + stride*j + x*out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit - // png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k=x*img_n; k >= 2; k-=2, ++in) { - *cur++ = scale * ((*in >> 4) ); - *cur++ = scale * ((*in ) & 0x0f); - } - if (k > 0) *cur++ = scale * ((*in >> 4) ); - } else if (depth == 2) { - for (k=x*img_n; k >= 4; k-=4, ++in) { - *cur++ = scale * ((*in >> 6) ); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in ) & 0x03); - } - if (k > 0) *cur++ = scale * ((*in >> 6) ); - if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k=x*img_n; k >= 8; k-=8, ++in) { - *cur++ = scale * ((*in >> 7) ); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in ) & 0x01); - } - if (k > 0) *cur++ = scale * ((*in >> 7) ); - if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride*j; - if (img_n == 1) { - for (q=x-1; q >= 0; --q) { - cur[q*2+1] = 255; - cur[q*2+0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q=x-1; q >= 0; --q) { - cur[q*4+3] = 255; - cur[q*4+2] = cur[q*3+2]; - cur[q*4+1] = cur[q*3+1]; - cur[q*4+0] = cur[q*3+0]; - } - } - } +#undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n + 1 == out_n); +#define STBI__CASE(f) \ + case f: \ + for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, \ + cur += output_bytes, prior += output_bytes) \ + for (k = 0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { + cur[k] = raw[k]; + } + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - output_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - output_bytes], prior[k], + prior[k - output_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); + } + break; } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc *cur = a->out; - stbi__uint16 *cur16 = (stbi__uint16*)cur; - - for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) { - *cur16 = (cur[0] << 8) | cur[1]; +#undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride * j; // start at the beginning of the row again + for (i = 0; i < x; ++i, cur += output_bytes) { + cur[filter_bytes + 1] = 255; + } } - } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *in = a->out + stride * j + x * out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common + // 8-bit path optimal at minimal cost for 1/2/4-bit png guarante byte + // alignment, if width is not multiple of 8/4/2 we'll decode dummy + // trailing data that will be skipped in the later loop + stbi_uc scale = (color == 0) + ? stbi__depth_scale_table[depth] + : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than + // desired. we can allocate enough data that this never writes out of + // memory, but it could also overwrite the next scanline. can it overwrite + // non-empty data on the next scanline? yes, consider 1-pixel-wide + // scanlines with 1-bit-per-pixel. so we need to explicitly clamp the + // final ones + + if (depth == 4) { + for (k = x * img_n; k >= 2; k -= 2, ++in) { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } else if (depth == 2) { + for (k = x * img_n; k >= 4; k -= 4, ++in) { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k = x * img_n; k >= 8; k -= 8, ++in) { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) + *cur++ = scale * ((*in >> 7)); + if (k > 1) + *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) + *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) + *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) + *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) + *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) + *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride * j; + if (img_n == 1) { + for (q = x - 1; q >= 0; --q) { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q = x - 1; q >= 0; --q) { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } + } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16 *)cur; + + for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, + stbi__uint32 image_data_len, int out_n, + int depth, int color, int interlaced) { + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, + a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + for (p = 0; p < 7; ++p) { + int xorig[] = {0, 4, 0, 2, 0, 1, 0}; + int yorig[] = {0, 0, 4, 0, 2, 0, 1}; + int xspc[] = {8, 8, 4, 4, 2, 2, 1}; + int yspc[] = {8, 8, 8, 4, 4, 2, 2}; + int i, j, x, y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, + y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j = 0; j < y; ++j) { + for (i = 0; i < x; ++i) { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, + a->out + (j * x + i) * out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - return 1; + return 1; } -static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) -{ - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc *final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - - // de-interlacing - final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - for (p=0; p < 7; ++p) { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; - int i,j,x,y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; - } - for (j=0; j < y; ++j) { - for (i=0; i < x; ++i) { - int out_y = j*yspc[p]+yorig[p]; - int out_x = i*xspc[p]+xorig[p]; - memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, - a->out + (j*x+i)*out_bytes, out_bytes); - } - } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - return 1; -} + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); -static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); - - if (out_n == 2) { - for (i=0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i=0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 *p = (stbi__uint16*) z->out; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], + int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16 *)z->out; - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) -{ - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; - - p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) return stbi__err("outofmem", "Out of memory"); - - // between here and free(out) below, exitting would leak - temp_out = p; - - if (pal_img_n == 3) { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p += 3; - } - } else { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p[3] = palette[n+3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, + int pal_img_n) { + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); - STBI_NOTUSED(len); + // between here and free(out) below, exitting would leak + temp_out = p; - return 1; + if (pal_img_n == 3) { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p += 3; + } + } else { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p[3] = palette[n + 3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; } static int stbi__unpremultiply_on_load = 0; static int stbi__de_iphone_flag = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) -{ - stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) -{ - stbi__de_iphone_flag = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { + stbi__de_iphone_flag = flag_true_if_should_convert; } -static void stbi__de_iphone(stbi__png *z) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - if (s->img_out_n == 3) { // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 3; +static void stbi__de_iphone(stbi__png *z) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i = 0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = (t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i=0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = ( t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } + } else { + // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a, b, c, d) \ + (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + \ + (unsigned)(d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) { + stbi_uc palette[1024], pal_img_n = 0; + stbi_uc has_trans = 0, tc[3] = {0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; + int first = 1, k, interlace = 0, color = 0, is_iphone = 0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) + return 0; + + if (scan == STBI__SCAN_type) + return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { + int comp, filter; + if (!first) + return stbi__err("multiple IHDR", "Corrupt PNG"); + first = 0; + if (c.length != 13) + return stbi__err("bad IHDR len", "Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + z->depth = stbi__get8(s); + if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && + z->depth != 16) + return stbi__err("1/2/4/8/16-bit only", + "PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); + if (color > 6) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3 && z->depth == 16) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3) + pal_img_n = 3; + else if (color & 1) + return stbi__err("bad ctype", "Corrupt PNG"); + comp = stbi__get8(s); + if (comp) + return stbi__err("bad comp method", "Corrupt PNG"); + filter = stbi__get8(s); + if (filter) + return stbi__err("bad filter method", "Corrupt PNG"); + interlace = stbi__get8(s); + if (interlace > 1) + return stbi__err("bad interlace method", "Corrupt PNG"); + if (!s->img_x || !s->img_y) + return stbi__err("0-pixel image", "Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) + return stbi__err("too large", "Image too large to decode"); + if (scan == STBI__SCAN_header) + return 1; } else { - // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) + return stbi__err("too large", "Corrupt PNG"); + // if SCAN_header, have to scan to see if we have a tRNS } - } -} - -#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + break; + } -static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) -{ - stbi_uc palette[1024], pal_img_n=0; - stbi_uc has_trans=0, tc[3]={0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; - int first=1,k,interlace=0, color=0, is_iphone=0; - stbi__context *s = z->s; - - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; - - if (!stbi__check_png_header(s)) return 0; - - if (scan == STBI__SCAN_type) return 1; - - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C','g','B','I'): - is_iphone = 1; - stbi__skip(s, c.length); - break; - case STBI__PNG_TYPE('I','H','D','R'): { - int comp,filter; - if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); - first = 0; - if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); - s->img_x = stbi__get32be(s); - s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); - comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); - filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); - interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); - if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); - if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); - if (scan == STBI__SCAN_header) return 1; - } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); - // if SCAN_header, have to scan to see if we have a tRNS - } - break; - } - - case STBI__PNG_TYPE('P','L','T','E'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); - pal_len = c.length / 3; - if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); - for (i=0; i < pal_len; ++i) { - palette[i*4+0] = stbi__get8(s); - palette[i*4+1] = stbi__get8(s); - palette[i*4+2] = stbi__get8(s); - palette[i*4+3] = 255; - } - break; - } - - case STBI__PNG_TYPE('t','R','N','S'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); - if (pal_img_n) { - if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } - if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); - if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); - pal_img_n = 4; - for (i=0; i < c.length; ++i) - palette[i*4+3] = stbi__get8(s); - } else { - if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); - if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); - has_trans = 1; - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } - } - break; - } - - case STBI__PNG_TYPE('I','D','A','T'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); - if (scan == STBI__SCAN_header) { s->img_n = pal_img_n; return 1; } - if ((int)(ioff + c.length) < (int)ioff) return 0; - if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc *p; - if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); - z->idata = p; - } - if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); - ioff += c.length; - break; - } - - case STBI__PNG_TYPE('I','E','N','D'): { - stbi__uint32 raw_len, bpl; - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) return 1; - if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); - // initial guess for decoded data size to avoid unnecessary reallocs - bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component - raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); - if (z->expanded == NULL) return 0; // zlib should set error - STBI_FREE(z->idata); z->idata = NULL; - if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n+1; - else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; - if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; - } - } - if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); - if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; - } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; - } - STBI_FREE(z->expanded); z->expanded = NULL; - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - return 1; - } - - default: - // if critical, fail - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if ((c.type & (1 << 29)) == 0) { - #ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); - #endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); - } - stbi__skip(s, c.length); - break; + case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256 * 3) + return stbi__err("invalid PLTE", "Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) + return stbi__err("invalid PLTE", "Corrupt PNG"); + for (i = 0; i < pal_len; ++i) { + palette[i * 4 + 0] = stbi__get8(s); + palette[i * 4 + 1] = stbi__get8(s); + palette[i * 4 + 2] = stbi__get8(s); + palette[i * 4 + 3] = 255; } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } -} + break; + } -static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) -{ - void *result=NULL; - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) return result; + case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) + return stbi__err("tRNS after IDAT", "Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { + s->img_n = 4; + return 1; + } + if (pal_len == 0) + return stbi__err("tRNS before PLTE", "Corrupt PNG"); + if (c.length > pal_len) + return stbi__err("bad tRNS len", "Corrupt PNG"); + pal_img_n = 4; + for (i = 0; i < c.length; ++i) + palette[i * 4 + 3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) + return stbi__err("tRNS with alpha", "Corrupt PNG"); + if (c.length != (stbi__uint32)s->img_n * 2) + return stbi__err("bad tRNS len", "Corrupt PNG"); + has_trans = 1; + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * + stbi__depth_scale_table[z->depth]; // non 8-bit images will + // be larger + } } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) *n = p->s->img_n; - } - STBI_FREE(p->out); p->out = NULL; - STBI_FREE(p->expanded); p->expanded = NULL; - STBI_FREE(p->idata); p->idata = NULL; - - return result; -} - -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi__png p; - p.s = s; - return stbi__do_png(&p, x,y,comp,req_comp, ri); -} - -static int stbi__png_test(stbi__context *s) -{ - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; -} + break; + } -static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) -{ - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind( p->s ); - return 0; - } - if (x) *x = p->s->img_x; - if (y) *y = p->s->img_y; - if (comp) *comp = p->s->img_n; - return 1; -} + case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) + return stbi__err("no PLTE", "Corrupt PNG"); + if (scan == STBI__SCAN_header) { + s->img_n = pal_img_n; + return 1; + } + if ((int)(ioff + c.length) < (int)ioff) + return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) + idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, + idata_limit); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata + ioff, c.length)) + return stbi__err("outofdata", "Corrupt PNG"); + ioff += c.length; + break; + } -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); -} + case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + stbi__uint32 raw_len, bpl; + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) + return 1; + if (z->idata == NULL) + return stbi__err("no IDAT", "Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag( + (char *)z->idata, ioff, raw_len, (int *)&raw_len, !is_iphone); + if (z->expanded == NULL) + return 0; // zlib should set error + STBI_FREE(z->idata); + z->idata = NULL; + if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || + has_trans) + s->img_out_n = s->img_n + 1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, + z->depth, color, interlace)) + return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) + return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) + return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) + s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); + z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } -static int stbi__png_is16(stbi__context *s) -{ - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; + default: + // if critical, fail + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { +#ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); +#endif + return stbi__err(invalid_chunk, + "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, + stbi__result_info *ri) { + void *result = NULL; + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", + "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) + return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) + *n = p->s->img_n; + } + STBI_FREE(p->out); + p->out = NULL; + STBI_FREE(p->expanded); + p->expanded = NULL; + STBI_FREE(p->idata); + p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi__png p; + p.s = s; + return stbi__do_png(&p, x, y, comp, req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) { + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) { + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind(p->s); + return 0; + } + if (x) + *x = p->s->img_x; + if (y) + *y = p->s->img_y; + if (comp) + *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) { + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) { + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context *s) -{ - int r; - int sz; - if (stbi__get8(s) != 'B') return 0; - if (stbi__get8(s) != 'M') return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; -} - -static int stbi__bmp_test(stbi__context *s) -{ - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test_raw(stbi__context *s) { + int r; + int sz; + if (stbi__get8(s) != 'B') + return 0; + if (stbi__get8(s) != 'M') + return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) { + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } - // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) -{ - int n=0; - if (z == 0) return -1; - if (z >= 0x10000) { n += 16; z >>= 16; } - if (z >= 0x00100) { n += 8; z >>= 8; } - if (z >= 0x00010) { n += 4; z >>= 4; } - if (z >= 0x00004) { n += 2; z >>= 2; } - if (z >= 0x00002) { n += 1;/* >>= 1;*/ } - return n; +static int stbi__high_bit(unsigned int z) { + int n = 0; + if (z == 0) + return -1; + if (z >= 0x10000) { + n += 16; + z >>= 16; + } + if (z >= 0x00100) { + n += 8; + z >>= 8; + } + if (z >= 0x00010) { + n += 4; + z >>= 4; + } + if (z >= 0x00004) { + n += 2; + z >>= 2; + } + if (z >= 0x00002) { + n += 1; /* >>= 1;*/ + } + return n; } -static int stbi__bitcount(unsigned int a) -{ - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) { + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) -{ - static unsigned int mul_table[9] = { +static int stbi__shiftsigned(unsigned int v, int shift, int bits) { + static unsigned int mul_table[9] = { 0, - 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, - 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0,0,1,0,2,4,6,0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8-bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; -} - -typedef struct -{ - int bpp, offset, hsz; - unsigned int mr,mg,mb,ma, all_a; - int extra_read; + 0xff /*0b11111111*/, + 0x55 /*0b01010101*/, + 0x49 /*0b01001001*/, + 0x11 /*0b00010001*/, + 0x21 /*0b00100001*/, + 0x41 /*0b01000001*/, + 0x81 /*0b10000001*/, + 0x01 /*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0, 0, 1, 0, 2, 4, 6, 0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8 - bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct { + int bpp, offset, hsz; + unsigned int mr, mg, mb, ma, all_a; + int extra_read; } stbi__bmp_data; -static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) -{ - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; - - if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); - - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - int i; - if (hsz != 108 && hsz != 124) +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) { + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') + return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) + return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) + return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) + return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) + return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha + // channel but it was all 0 + } else { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - stbi__get32le(s); // discard color space - for (i=0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); } - } - return (void *) 1; -} - - -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - unsigned int mr=0,mg=0,mb=0,ma=0, all_a; - stbi_uc pal[256][4]; - int psize=0,i,j,width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set - - flip_vertically = ((int) s->img_y) > 0; - s->img_y = abs((int) s->img_y); - - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; - - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - STBI_ASSERT(info.offset == s->callback_already_read + (int) (s->img_buffer - s->img_buffer_original)); - if (info.offset != s->callback_already_read + (s->img_buffer - s->buffer_start)) { - return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + stbi__get32le(s); // discard color space + for (i = 0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved } - } - - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert - - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); - - out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z=0; - if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } - for (i=0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) stbi__get8(s); - pal[i][3] = 255; + } + } + return (void *)1; +} + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; + stbi_uc pal[256][4]; + int psize = 0, i, j, width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int)s->img_y) > 0; + s->img_y = abs((int)s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + STBI_ASSERT(info.offset == + s->callback_already_read + + (int)(s->img_buffer - s->img_buffer_original)); + if (info.offset != + s->callback_already_read + (s->img_buffer - s->buffer_start)) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z = 0; + if (psize == 0 || psize > 256) { + STBI_FREE(out); + return stbi__errpuc("invalid", "Corrupt BMP"); + } + for (i = 0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) + stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - + psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) + width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) + width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) + width = s->img_x; + else { + STBI_FREE(out); + return stbi__errpuc("bad bpp", "Corrupt BMP"); + } + pad = (-width) & 3; + if (info.bpp == 1) { + for (j = 0; j < (int)s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i = 0; i < (int)s->img_x; ++i) { + int color = (v >> bit_offset) & 0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + if ((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) width = s->img_x; - else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } - pad = (-width)&3; - if (info.bpp == 1) { - for (j=0; j < (int) s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i=0; i < (int) s->img_x; ++i) { - int color = (v>>bit_offset)&0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - if((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j=0; j < (int) s->img_y; ++j) { - for (i=0; i < (int) s->img_x; i += 2) { - int v=stbi__get8(s),v2=0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - } - stbi__skip(s, pad); - } + } else { + for (j = 0; j < (int)s->img_y; ++j) { + for (i = 0; i < (int)s->img_x; i += 2) { + int v = stbi__get8(s), v2 = 0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + } + stbi__skip(s, pad); } - } else { - int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; - int z = 0; - int easy=0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) width = 3 * s->img_x; - else if (info.bpp == 16) width = 2*s->img_x; - else /* bpp = 32 and pad = 0 */ width=0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; + } + } else { + int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, + bcount = 0, acount = 0; + int z = 0; + int easy = 0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) + width = 3 * s->img_x; + else if (info.bpp == 16) + width = 2 * s->img_x; + else /* bpp = 32 and pad = 0 */ + width = 0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - if (!easy) { - if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr) - 7; + rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg) - 7; + gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb) - 7; + bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma) - 7; + acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - for (j=0; j < (int) s->img_y; ++j) { - if (easy) { - for (i=0; i < (int) s->img_x; ++i) { - unsigned char a; - out[z+2] = stbi__get8(s); - out[z+1] = stbi__get8(s); - out[z+0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i=0; i < (int) s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) out[z++] = STBI__BYTECAST(a); - } - } - stbi__skip(s, pad); + } + for (j = 0; j < (int)s->img_y; ++j) { + if (easy) { + for (i = 0; i < (int)s->img_x; ++i) { + unsigned char a; + out[z + 2] = stbi__get8(s); + out[z + 1] = stbi__get8(s); + out[z + 0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) + out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i = 0; i < (int)s->img_x; ++i) { + stbi__uint32 v = + (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) + out[z++] = STBI__BYTECAST(a); + } } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j=0; j < (int) s->img_y>>1; ++j) { - stbi_uc *p1 = out + j *s->img_x*target; - stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; - for (i=0; i < (int) s->img_x*target; ++i) { - t = p1[i]; p1[i] = p2[i]; p2[i] = t; - } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j = 0; j < (int)s->img_y >> 1; ++j) { + stbi_uc *p1 = out + j * s->img_x * target; + stbi_uc *p2 = out + (s->img_y - 1 - j) * s->img_x * target; + for (i = 0; i < (int)s->img_x * target; ++i) { + t = p1[i]; + p1[i] = p2[i]; + p2[i] = t; } - } + } + } - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; - return out; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + return out; } #endif @@ -5555,592 +6271,625 @@ static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) -{ - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) *is_rgb16 = 0; - switch(bits_per_pixel) { - case 8: return STBI_grey; - case 16: if(is_grey) return STBI_grey_alpha; - // fallthrough - case 15: if(is_rgb16) *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: return bits_per_pixel/8; - default: return 0; - } -} - -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) -{ - int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; - int sz, tga_colormap_type; - stbi__get8(s); // discard Offset - tga_colormap_type = stbi__get8(s); // colormap type - if( tga_colormap_type > 1 ) { - stbi__rewind(s); - return 0; // only RGB or indexed allowed - } - tga_image_type = stbi__get8(s); // image type - if ( tga_colormap_type == 1 ) { // colormapped (paletted) image - if (tga_image_type != 1 && tga_image_type != 9) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip image x and y origin - tga_colormap_bpp = sz; - } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { - stbi__rewind(s); - return 0; // only RGB or grey allowed, +/- RLE - } - stbi__skip(s,9); // skip colormap specification and image x/y origin - tga_colormap_bpp = 0; - } - tga_w = stbi__get16le(s); - if( tga_w < 1 ) { - stbi__rewind(s); - return 0; // test width +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int *is_rgb16) { + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) + *is_rgb16 = 0; + switch (bits_per_pixel) { + case 8: + return STBI_grey; + case 16: + if (is_grey) + return STBI_grey_alpha; + // fallthrough + case 15: + if (is_rgb16) + *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: + return bits_per_pixel / 8; + default: + return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) { + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, + tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if (tga_colormap_type > 1) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; } - tga_h = stbi__get16le(s); - if( tga_h < 1 ) { - stbi__rewind(s); - return 0; // test height + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__rewind(s); + return 0; } - tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits - if (tga_colormap_bpp != 0) { - if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { - // when using a colormap, tga_bits_per_pixel is the size of the indexes - // I don't think anything but 8 or 16bit indexes makes sense - stbi__rewind(s); - return 0; - } - tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); - } else { - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + stbi__skip(s, 4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ((tga_image_type != 2) && (tga_image_type != 3) && + (tga_image_type != 10) && (tga_image_type != 11)) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE } - if(!tga_comp) { + stbi__skip(s, 9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if (tga_w < 1) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if (tga_h < 1) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); return 0; } - if (x) *x = tga_w; - if (y) *y = tga_h; - if (comp) *comp = tga_comp; - return 1; // seems to have passed everything -} - -static int stbi__tga_test(stbi__context *s) -{ - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if ( tga_color_type == 1 ) { // colormapped (paletted) image - if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - stbi__skip(s,4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s,9); // skip colormap specification and image x/y origin - } - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp( + tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), + NULL); + } + if (!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) + *x = tga_w; + if (y) + *y = tga_h; + if (comp) + *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) { + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if (tga_color_type > 1) + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image + if (sz != 1 && sz != 9) + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin + } + if (stbi__get16le(s) < 1) + goto errorEnd; // test width + if (stbi__get16le(s) < 1) + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) + goto errorEnd; // for colormapped images, bpp is size of an index + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead + // of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) -{ - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255)/31); - out[1] = (stbi_uc)((g * 255)/31); - out[2] = (stbi_uc)((b * 255)/31); - - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. -} - -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16=0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char *tga_data; - unsigned char *tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO - - if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // do a tiny bit of precessing - if ( tga_image_type >= 8 ) - { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); - - // If I'm paletted, then I'll use the number of bits from the palette - if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - - if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - - // tga info - *x = tga_width; - *y = tga_height; - if (comp) *comp = tga_comp; - - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); - - tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); - - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset ); - - if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { - for (i=0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height -i - 1 : i; - stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if ( tga_indexed) - { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start ); - // load the palette - tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); - } - if (tga_rgb16) { - stbi_uc *pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i=0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc *out) { + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be + // swapped later + out[0] = (stbi_uc)((r * 255) / 31); + out[1] = (stbi_uc)((g * 255) / 31); + out[2] = (stbi_uc)((b * 255) / 31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16 = 0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused + // (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // do a tiny bit of precessing + if (tga_image_type >= 8) { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if (tga_indexed) + tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), + &tga_rgb16); + + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have + // ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) + *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = + (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) + return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset); + + if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { + for (i = 0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height - i - 1 : i; + stbi_uc *tga_row = tga_data + row * tga_width * tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if (tga_indexed) { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // load the data - for (i=0; i < tga_width * tga_height; ++i) - { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if ( tga_is_RLE ) - { - if ( RLE_count == 0 ) - { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if ( !RLE_repeating ) - { - read_next_pixel = 1; - } - } else - { - read_next_pixel = 1; - } - // OK, if I need to read a pixel, do it now - if ( read_next_pixel ) - { - // load however much data we did have - if ( tga_indexed ) - { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if ( pal_idx >= tga_palette_len ) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx+j]; - } - } else if(tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel - - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i*tga_comp+j] = raw_data[j]; - // in case we're in RLE mode, keep counting down - --RLE_count; + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start); + // load the palette + tga_palette = + (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); } - // do I need to invert the image? - if ( tga_inverted ) - { - for (j = 0; j*2 < tga_height; ++j) - { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) - { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } - } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i = 0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // clear my palette, if I had one - if ( tga_palette != NULL ) - { - STBI_FREE( tga_palette ); + } + // load the data + for (i = 0; i < tga_width * tga_height; ++i) { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if (tga_is_RLE) { + if (RLE_count == 0) { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if (!RLE_repeating) { + read_next_pixel = 1; + } + } else { + read_next_pixel = 1; } - } + // OK, if I need to read a pixel, do it now + if (read_next_pixel) { + // load however much data we did have + if (tga_indexed) { + // read in index, then perform the lookup + int pal_idx = + (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if (pal_idx >= tga_palette_len) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx + j]; + } + } else if (tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) - { - unsigned char* tga_pixel = tga_data; - for (i=0; i < tga_width * tga_height; ++i) - { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i * tga_comp + j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if (tga_inverted) { + for (j = 0; j * 2 < tga_height; ++j) { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } } - } + } + // clear my palette, if I had one + if (tga_palette != NULL) { + STBI_FREE(tga_palette); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) { + unsigned char *tga_pixel = tga_data; + for (i = 0; i < tga_width * tga_height; ++i) { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, + tga_height); - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = - tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = + tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; } #endif // ************************************************************************************************* -// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, +// tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s) -{ - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; -} - -static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) -{ - int count, nleft, len; - - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; - } +static int stbi__psd_test(stbi__context *s) { + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) { + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) + return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; } - } - - return 1; -} + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) + return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w, h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", + "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", + "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for + // other modes.) + stbi__skip(s, stbi__get32be(s)); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s)); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s)); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", + "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *)stbi__malloc(4 * w * h); + + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w * h; + + // Initialize the data to zero. + // memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes + // literally. Else if n is between -127 and -1 inclusive, copy the next + // byte -n+1 times. Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row + // in the data, which we're going to just skip. + stbi__skip(s, h * channelCount * 2); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out + channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w,h; - stbi_uc *out; - STBI_NOTUSED(ri); - - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); - - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - - // Skip 6 reserved bytes. - stbi__skip(s, 6 ); - - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); - - if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s,stbi__get32be(s) ); - - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s) ); - - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s) ); - - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); - - // Create the destination image. - - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *) stbi__malloc(4 * w*h); - - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w*h; - - // Initialize the data to zero. - //memset( out, 0, pixelCount * 4 ); - - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop - - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2 ); - - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc *p; - - p = out+channel; - if (channel >= channelCount) { - // Fill this channel with default data. + } else { + // We're at the raw image data. It's each channel in order (Red, Green, + // Blue, Alpha, ...) where each channel consists of an 8-bit (or 16-bit) + // value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out + channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16)stbi__get16be(s); + } else { + stbi_uc *p = out + channel; + if (bitdepth == 16) { // input bpc for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); - } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } - } + *p = (stbi_uc)(stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } } - - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc *p = out+channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } - } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16) stbi__get16be(s); - } else { - stbi_uc *p = out+channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc) (stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } - } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i = 0; i < w * h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *)out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); + pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); + pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); + } } - } - - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i=0; i < w*h; ++i) { - stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); - pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); - pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); - } - } - } else { - for (i=0; i < w*h; ++i) { - unsigned char *pixel = out + 4*i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); - pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); - pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); - } - } + } else { + for (i = 0; i < w * h; ++i) { + unsigned char *pixel = out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); + pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); + pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); + } } - } - - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - - if (comp) *comp = 4; - *y = h; - *x = w; - - return out; + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, + w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + if (comp) + *comp = 4; + *y = h; + *x = w; + + return out; } #endif @@ -6152,215 +6901,222 @@ static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context *s,const char *str) -{ - int i; - for (i=0; i<4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s, const char *str) { + int i; + for (i = 0; i < 4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context *s) -{ - int i; +static int stbi__pic_test_core(stbi__context *s) { + int i; - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) + return 0; - for(i=0;i<84;++i) - stbi__get8(s); + for (i = 0; i < 84; ++i) + stbi__get8(s); - if (!stbi__pic_is4(s,"PICT")) - return 0; + if (!stbi__pic_is4(s, "PICT")) + return 0; - return 1; + return 1; } -typedef struct -{ - stbi_uc size,type,channel; +typedef struct { + stbi_uc size, type, channel; } stbi__pic_packet; -static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) -{ - int mask=0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) { + int mask = 0x80, i; - for (i=0; i<4; ++i, mask>>=1) { - if (channel & mask) { - if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); - dest[i]=stbi__get8(s); - } - } + for (i = 0; i < 4; ++i, mask >>= 1) { + if (channel & mask) { + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "PIC file too short"); + dest[i] = stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) -{ - int mask=0x80,i; +static void stbi__copyval(int channel, stbi_uc *dest, const stbi_uc *src) { + int mask = 0x80, i; - for (i=0;i<4; ++i, mask>>=1) - if (channel&mask) - dest[i]=src[i]; + for (i = 0; i < 4; ++i, mask >>= 1) + if (channel & mask) + dest[i] = src[i]; } -static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) -{ - int act_comp=0,num_packets=0,y,chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s, int width, int height, + int *comp, stbi_uc *result) { + int act_comp = 0, num_packets = 0, y, chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data - // for the same channel in multiple packets. - do { - stbi__pic_packet *packet; + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. + do { + stbi__pic_packet *packet; - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return stbi__errpuc("bad format","too many packets"); + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return stbi__errpuc("bad format", "too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); - if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (reading packets)"); + if (packet->size != 8) + return stbi__errpuc("bad format", "packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for(y=0; ytype) { - default: - return stbi__errpuc("bad format","packet has bad compression type"); + switch (packet->type) { + default: + return stbi__errpuc("bad format", "packet has bad compression type"); - case 0: {//uncompressed - int x; + case 0: { // uncompressed + int x; - for(x=0;xchannel,dest)) - return 0; - break; - } + for (x = 0; x < width; ++x, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + break; + } - case 1://Pure RLE - { - int left=width, i; - - while (left>0) { - stbi_uc count,value[4]; - - count=stbi__get8(s); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - - if (count > left) - count = (stbi_uc) left; - - if (!stbi__readval(s,packet->channel,value)) return 0; - - for(i=0; ichannel,dest,value); - left -= count; - } - } - break; - - case 2: {//Mixed RLE - int left=width; - while (left>0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); - - if (count >= 128) { // Repeated - stbi_uc value[4]; - - if (count==128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file","scanline overrun"); - - if (!stbi__readval(s,packet->channel,value)) - return 0; - - for(i=0;ichannel,dest,value); - } else { // Raw - ++count; - if (count>left) return stbi__errpuc("bad file","scanline overrun"); - - for(i=0;ichannel,dest)) - return 0; - } - left-=count; - } - break; - } - } + case 1: // Pure RLE + { + int left = width, i; + + while (left > 0) { + stbi_uc count, value[4]; + + count = stbi__get8(s); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pure read count)"); + + if (count > left) + count = (stbi_uc)left; + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + left -= count; + } + } break; + + case 2: { // Mixed RLE + int left = width; + while (left > 0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", + "file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count == 128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + } else { // Raw + ++count; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + for (i = 0; i < count; ++i, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + } + left -= count; + } + break; + } } - } + } + } - return result; + return result; } -static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) -{ - stbi_uc *result; - int i, x,y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s, int *px, int *py, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *result; + int i, x, y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) comp = &internal_comp; + if (!comp) + comp = &internal_comp; - for (i=0; i<92; ++i) - stbi__get8(s); + for (i = 0; i < 92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) + return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); //skip `ratio' - stbi__get16be(s); //skip `fields' - stbi__get16be(s); //skip `pad' + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); - memset(result, 0xff, x*y*4); + // intermediate buffer is RGBA + result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); + memset(result, 0xff, x * y * 4); - if (!stbi__pic_load_core(s,x,y,comp, result)) { - STBI_FREE(result); - result=0; - } - *px = x; - *py = y; - if (req_comp == 0) req_comp = *comp; - result=stbi__convert_format(result,4,req_comp,x,y); + if (!stbi__pic_load_core(s, x, y, comp, result)) { + STBI_FREE(result); + result = 0; + } + *px = x; + *py = y; + if (req_comp == 0) + req_comp = *comp; + result = stbi__convert_format(result, 4, req_comp, x, y); - return result; + return result; } -static int stbi__pic_test(stbi__context *s) -{ - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) { + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6368,514 +7124,539 @@ static int stbi__pic_test(stbi__context *s) // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct -{ - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct { + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct -{ - int w,h; - stbi_uc *out; // output buffer (always 4 components) - stbi_uc *background; // The current "background" as far as a gif is concerned - stbi_uc *history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc *color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct { + int w, h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context *s) -{ - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') return 0; - if (stbi__get8(s) != 'a') return 0; - return 1; -} - -static int stbi__gif_test(stbi__context *s) -{ - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; -} - -static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) -{ - int i; - for (i=0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } -} +static int stbi__gif_test_raw(stbi__context *s) { + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') + return 0; + if (stbi__get8(s) != 'a') + return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) { + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], + int num_entries, int transp) { + int i; + for (i = 0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, + int is_info) { + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') + return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') + return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + + if (comp != 0) + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the + // comments + + if (is_info) + return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) { + stbi__gif *g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind(s); + return 0; + } + if (x) + *x = g->w; + if (y) + *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) { + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) + return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) { + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) + return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc)init_code; + g->codes[init_code].suffix = (stbi_uc)init_code; + } + + // support no starting clear code + avail = clear + 2; + oldcode = -1; + + len = 0; + for (;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32)stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s, len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } -static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) -{ - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - version = stbi__get8(s); - if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + p->prefix = (stbi__int16)oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; + stbi__out_gif_code(g, (stbi__uint16)code); - if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } - if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image +// doesn't support it two back is the image from two frames ago, used for a very +// specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, + int req_comp, stbi_uc *two_back) { + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour + // (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp, 0)) + return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *)stbi__malloc(4 * pcount); + g->background = (stbi_uc *)stbi__malloc(4 * pcount); + g->history = (stbi_uc *)stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); - if (is_info) return 1; + // image is treated as "transparent" at the start - ie, nothing overwrites + // the current background; background colour is only used for pixels that + // are not rendered first frame, after that "background" color refers to the + // color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, + 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, + pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the + // old background + } - if (g->flags & 0x80) - stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } - return 1; -} + // background is what out is after the undoing of the previou frame; + memcpy(g->background, g->out, 4 * g->w * g->h); + } + + // clear my history; + memset(g->history, 0x00, + g->w * g->h); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc *o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; -static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind( s ); - return 0; - } - if (x) *x = g->w; - if (y) *y = g->h; - STBI_FREE(g); - return 1; -} + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; -static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) -{ - stbi_uc *p, *c; - int idx; - - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); - - if (g->cur_y >= g->max_y) return; - - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; - - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; - - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + g->lflags = stbi__get8(s); - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; } - } -} -static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) -{ - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw *p; - - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc) init_code; - g->codes[init_code].suffix = (stbi_uc) init_code; - } - - // support no starting clear code - avail = clear+2; - oldcode = -1; - - len = 0; - for(;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; - } - --len; - bits |= (stbi__int32) stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s,len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), + g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *)g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *)g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } + o = stbi__process_gif_raster(s, g); + if (!o) + return NULL; - p->prefix = (stbi__int16) oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = + 255; // just in case it was made transparent, undo that; It will + // be reset next frame if need be; + memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); + } + } + } - stbi__out_gif_code(g, (stbi__uint16) code); + return o; + } - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = + 10 * stbi__get16le( + s); // delay - 1/100th of a second, saving as 1/1000ths. + + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } - } -} - -// this function is designed to support animated gifs, although stb_image doesn't support it -// two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) -{ - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); - - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *) stbi__malloc(4 * pcount); - g->background = (stbi_uc *) stbi__malloc(4 * pcount); - g->history = (stbi_uc *) stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); - - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; - - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); } + break; + } - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); - } - } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); - } - } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers - // background is what out is after the undoing of the previou frame; - memcpy( g->background, g->out, 4 * g->w * g->h ); - } - - // clear my history; - memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame - - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { - stbi__int32 x, y, w, h; - stbi_uc *o; - - x = stbi__get16le(s); - y = stbi__get16le(s); - w = stbi__get16le(s); - h = stbi__get16le(s); - if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); - - g->line_size = g->w * 4; - g->start_x = x * 4; - g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; - - // if the width of the specified rectangle is 0, that means - // we may not see *any* pixels or the image is malformed; - // to make sure this is caught, move the current y down to - // max_y (which is what out_gif_code checks). - if (w == 0) - g->cur_y = g->max_y; - - g->lflags = stbi__get8(s); - - if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; - } else { - g->step = g->line_size; - g->parse = 0; - } + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp) { + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } - if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *) g->lpal; - } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *) g->pal; - } else - return stbi__errpuc("missing color table", "Corrupt GIF"); - - o = stbi__process_gif_raster(s, g); - if (!o) return NULL; - - // if this was the first frame, - pcount = g->w * g->h; - if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); - } - } - } + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = + (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + if (NULL == tmp) { + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + return stbi__errpuc("outofmem", "Out of memory"); + } else { + out = (stbi_uc *)tmp; + out_size = layers * stride; + } + + if (delays) { + *delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, + sizeof(int) * layers); + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc *)stbi__malloc(layers * stride); + out_size = layers * stride; + if (delays) { + *delays = (int *)stbi__malloc(layers * sizeof(int)); + delays_size = layers * sizeof(int); + } + } + memcpy(out + ((layers - 1) * stride), u, stride); + if (layers >= 2) { + two_back = out - 2 * stride; + } - return o; - } - - case 0x21: // Comment Extension. - { - int len; - int ext = stbi__get8(s); - if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } - } - while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); - } - break; - } + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); - case 0x3B: // gif stream termination code - return (stbi_uc *) s; // using '1' causes warning on some compilers + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - default: - return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } -} + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc *u = 0; - stbi_uc *out = 0; - stbi_uc *two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; - - if (out) { - void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); - if (NULL == tmp) { - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); - return stbi__errpuc("outofmem", "Out of memory"); - } - else { - out = (stbi_uc*) tmp; - out_size = layers * stride; - } - - if (delays) { - *delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc*)stbi__malloc( layers * stride ); - out_size = layers * stride; - if (delays) { - *delays = (int*) stbi__malloc( layers * sizeof(int) ); - delays_size = layers * sizeof(int); - } - } - memcpy( out + ((layers - 1) * stride), u, stride ); - if (layers >= 2) { - two_back = out - 2 * stride; - } +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); - if (delays) { - (*delays)[layers - 1U] = g.delay; - } - } - } while (u != 0); + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + return u; } -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); - - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } - - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); - - return u; -} - -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) -{ - return stbi__gif_info_raw(s,x,y,comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) { + return stbi__gif_info_raw(s, x, y, comp); } #endif @@ -6883,396 +7664,434 @@ static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context *s, const char *signature) -{ - int i; - for (i=0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; -} - -static int stbi__hdr_test(stbi__context* s) -{ - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if(!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; -} - -#define STBI__HDR_BUFLEN 1024 -static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) -{ - int len=0; - char c = '\0'; - - c = (char) stbi__get8(z); - - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN-1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) { + int i; + for (i = 0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context *s) { + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if (!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) { + int len = 0; + char c = '\0'; + + c = (char)stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN - 1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char)stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) { + if (input[3] != 0) { + float f1; + // Exponent + f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) + output[1] = 1; + if (req_comp == 4) + output[3] = 1; + } else { + switch (req_comp) { + case 4: + output[3] = 1; /* fallthrough */ + case 3: + output[0] = output[1] = output[2] = 0; + break; + case 2: + output[1] = 1; /* fallthrough */ + case 1: + output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1, c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s, buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && + strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) + return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int)strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) + *comp = 3; + if (req_comp == 0) + req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = + (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if (width < 8 || width >= 32768) { + // Read flat data + for (j = 0; j < height; ++j) { + for (i = 0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, + req_comp); } - c = (char) stbi__get8(z); - } - - buffer[len] = 0; - return buffer; -} + } + } else { + // Read RLE-encoded data + scanline = NULL; -static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) -{ - if ( input[3] != 0 ) { - float f1; - // Exponent - f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a + // decoded pixel (note this can't be a valid pixel--one of RGB must be + // >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc)c1; + rgbe[1] = (stbi_uc)c2; + rgbe[2] = (stbi_uc)len; + rgbe[3] = (stbi_uc)stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense } - if (req_comp == 2) output[1] = 1; - if (req_comp == 4) output[3] = 1; - } else { - switch (req_comp) { - case 4: output[3] = 1; /* fallthrough */ - case 3: output[0] = output[1] = output[2] = 0; - break; - case 2: output[1] = 1; /* fallthrough */ - case 1: output[0] = 0; - break; + len <<= 8; + len |= stbi__get8(s); + if (len != width) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - } -} - -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int width, height; - stbi_uc *scanline; - float *hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1,c2, z; - const char *headerToken; - STBI_NOTUSED(ri); - - // Check identifier - headerToken = stbi__hdr_gettoken(s,buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); - - // Parse header - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); - - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int) strtol(token, NULL, 10); - - if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - - *x = width; - *y = height; - - if (comp) *comp = 3; - if (req_comp == 0) req_comp = 3; - - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); - - // Read data - hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); - - // Load image data - // image data is stored as some number of sca - if ( width < 8 || width >= 32768) { - // Read flat data - for (j=0; j < height; ++j) { - for (i=0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } + if (scanline == NULL) { + scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } } - } else { - // Read RLE-encoded data - scanline = NULL; - - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc) c1; - rgbe[1] = (stbi_uc) c2; - rgbe[2] = (stbi_uc) len; - rgbe[3] = (stbi_uc) stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - if (scanline == NULL) { - scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - for (i=0; i < width; ++i) - stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } } - if (scanline) - STBI_FREE(scanline); - } - - return hdr_data; -} - -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int dummy; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (stbi__hdr_test(s) == 0) { - stbi__rewind( s ); - return 0; - } - - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) { - stbi__rewind( s ); - return 0; - } - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *y = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *x = (int) strtol(token, NULL, 10); - *comp = 3; - return 1; + for (i = 0; i < width; ++i) + stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, + scanline + i * 4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind(s); + return 0; + } + + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) { + stbi__rewind(s); + return 0; + } + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *y = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *x = (int)strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) -{ - void *p; - stbi__bmp_data info; - - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - stbi__rewind( s ); - if (p == NULL) - return 0; - if (x) *x = s->img_x; - if (y) *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) { + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + stbi__rewind(s); + if (p == NULL) + return 0; + if (x) + *x = s->img_x; + if (y) + *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) -{ - int channelCount, dummy, depth; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind( s ); - return 0; - } - *comp = 4; - return 1; -} - -static int stbi__psd_is16(stbi__context *s) -{ - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - (void) stbi__get32be(s); - (void) stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) { + int channelCount, dummy, depth; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind(s); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) { + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + (void)stbi__get32be(s); + (void)stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind(s); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) -{ - int act_comp=0,num_packets=0,chained,dummy; - stbi__pic_packet packets[10]; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { - stbi__rewind(s); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) { + int act_comp = 0, num_packets = 0, chained, dummy; + stbi__pic_packet packets[10]; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) return 0; - } - stbi__skip(s, 88); + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind( s); + if (stbi__at_eof(s)) { + stbi__rewind(s); return 0; - } - if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind( s ); + } + if (packet->size != 8) { + stbi__rewind(s); return 0; - } - - stbi__skip(s, 8); - - do { - stbi__pic_packet *packet; - - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return 0; - - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; - - if (stbi__at_eof(s)) { - stbi__rewind( s ); - return 0; - } - if (packet->size != 8) { - stbi__rewind( s ); - return 0; - } - } while (chained); + } + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); + *comp = (act_comp & 0x10 ? 4 : 3); - return 1; + return 1; } #endif @@ -7290,257 +8109,266 @@ static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s) -{ - char p, t; - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) { + char p, t; + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + return 1; } -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - STBI_NOTUSED(ri); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + STBI_NOTUSED(ri); - if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) - return 0; + if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) + return 0; - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; - if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "PNM too large"); + if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "PNM too large"); - out = (stbi_uc *) stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - stbi__getn(s, out, s->img_n * s->img_x * s->img_y); + out = (stbi_uc *)stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + stbi__getn(s, out, s->img_n * s->img_x * s->img_y); - if (req_comp && req_comp != s->img_n) { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - return out; + if (req_comp && req_comp != s->img_n) { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + return out; } -static int stbi__pnm_isspace(char c) -{ - return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +static int stbi__pnm_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; } -static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) -{ - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char) stbi__get8(s); +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) { + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char)stbi__get8(s); - if (stbi__at_eof(s) || *c != '#') - break; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') + *c = (char)stbi__get8(s); + } } -static int stbi__pnm_isdigit(char c) -{ - return c >= '0' && c <= '9'; +static int stbi__pnm_isdigit(char c) { + return c >= '0' && c <= '9'; } -static int stbi__pnm_getinteger(stbi__context *s, char *c) -{ - int value = 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) { + int value = 0; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value*10 + (*c - '0'); - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value * 10 + (*c - '0'); + *c = (char)stbi__get8(s); + } - return value; + return value; } -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) -{ - int maxv, dummy; - char c, p, t; +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) { + int maxv, dummy; + char c, p, t; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; - stbi__rewind(s); + stbi__rewind(s); - // Get identifier - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } + // Get identifier + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = + (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - c = (char) stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); + c = (char)stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width - stbi__pnm_skip_whitespace(s, &c); + *x = stbi__pnm_getinteger(s, &c); // read width + stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height - stbi__pnm_skip_whitespace(s, &c); + *y = stbi__pnm_getinteger(s, &c); // read height + stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value + maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 255) - return stbi__err("max value > 255", "PPM image not 8-bit"); - else - return 1; + if (maxv > 255) + return stbi__err("max value > 255", "PPM image not 8-bit"); + else + return 1; } #endif -static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) -{ - #ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) return 1; - #endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) { +#ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) + return 1; +#endif - // test tga last because it's a crappy test! - #ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; - #endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +// test tga last because it's a crappy test! +#ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; +#endif + return stbi__err("unknown image type", + "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context *s) -{ - #ifndef STBI_NO_PNG - if (stbi__png_is16(s)) return 1; - #endif +static int stbi__is_16_main(stbi__context *s) { +#ifndef STBI_NO_PNG + if (stbi__png_is16(s)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) + return 1; +#endif - return 0; + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_info_from_file(f, x, y, comp); - fclose(f); - return result; -} - -STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s,x,y,comp); - fseek(f,pos,SEEK_SET); - return r; -} - -STBIDEF int stbi_is_16_bit(char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_is_16_bit_from_file(f); - fclose(f); - return result; -} - -STBIDEF int stbi_is_16_bit_from_file(FILE *f) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f,pos,SEEK_SET); - return r; +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s, x, y, comp); + fseek(f, pos, SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE *f) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f, pos, SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, + int *x, int *y, int *comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, + void *user) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION /* revision history: - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs - 2.19 (2018-02-11) fix warning - 2.18 (2018-01-30) fix warnings - 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and + platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix + warnings 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug 1-bit BMP *_is_16_bit api avoid warnings @@ -7555,13 +8383,11 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user warning fixes; disable run-time SSE detection on gcc; uniform handling of optional "return" values; thread-safe initialization of zlib tables - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) allocate large structures on the stack - remove white matting for transparent PSD - fix reported channel count for PNG & BMP - re-enable SSE2 in non-gcc 64-bit + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet + JPGs 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now 2.12 + (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 (2016-04-02) + allocate large structures on the stack remove white matting for transparent + PSD fix reported channel count for PNG & BMP re-enable SSE2 in non-gcc 64-bit support RGB-formatted JPEG read 16-bit PNGs (only as 8-bit) 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED @@ -7569,11 +8395,9 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 16-bit-per-pixel TGA (not bit-per-component) info() for TGA could break due to .hdr handling info() for BMP to shares code instead of sloppy parse - can use STBI_REALLOC_SIZED if allocator doesn't support realloc - code cleanup - 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA - 2.07 (2015-09-13) fix compiler warnings - partial animated GIF support + can use STBI_REALLOC_SIZED if allocator doesn't support + realloc code cleanup 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD + as RGBA 2.07 (2015-09-13) fix compiler warnings partial animated GIF support limited 16-bpc PSD support #ifdef unused functions bug with < 92 byte PIC,PNM,HDR,TGA @@ -7584,23 +8408,18 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user stbi_set_flip_vertically_on_load (nguillemot) fix NEON support; fix mingw support 2.02 (2015-01-19) fix incorrect assert, fix warning - 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 - 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG - 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) - progressive JPEG (stb) - PGM/PPM support (Ken Miller) - STBI_MALLOC,STBI_REALLOC,STBI_FREE + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit + without -msse2 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG 2.00 + (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) progressive + JPEG (stb) PGM/PPM support (Ken Miller) STBI_MALLOC,STBI_REALLOC,STBI_FREE GIF bugfix -- seemingly never worked STBI_NO_*, STBI_ONLY_* 1.48 (2014-12-14) fix incorrectly-named assert() - 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) - optimize PNG (ryg) - fix bug in interlaced PNG with user-specified channel count (stb) - 1.46 (2014-08-26) - fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG - 1.45 (2014-08-16) - fix MSVC-ARM internal compiler error by wrapping malloc - 1.44 (2014-08-07) + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar + Cornut & stb) optimize PNG (ryg) fix bug in interlaced PNG with + user-specified channel count (stb) 1.46 (2014-08-26) fix broken tRNS chunk + (colorkey-style transparency) in non-paletted PNG 1.45 (2014-08-16) fix + MSVC-ARM internal compiler error by wrapping malloc 1.44 (2014-08-07) various warning fixes from Ronny Chevalier 1.43 (2014-07-15) fix MSVC-only compiler problem in code changed in 1.42 @@ -7609,73 +8428,48 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user fixes to stbi__cleanup_jpeg path added STBI_ASSERT to avoid requiring assert.h 1.41 (2014-06-25) - fix search&replace from 1.36 that messed up comments/error messages - 1.40 (2014-06-22) - fix gcc struct-initialization warning - 1.39 (2014-06-15) - fix to TGA optimization when req_comp != number of components in TGA; - fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) - add support for BMP version 5 (more ignored fields) - 1.38 (2014-06-06) - suppress MSVC warnings on integer casts truncating values - fix accidental rename of 'skip' field of I/O - 1.37 (2014-06-04) - remove duplicate typedef - 1.36 (2014-06-03) - convert to header file single-file library - if de-iphone isn't set, load iphone images color-swapped instead of returning NULL - 1.35 (2014-05-27) - various warnings - fix broken STBI_SIMD path - fix bug where stbi_load_from_file no longer left file pointer in correct place - fix broken non-easy path for 32-bit BMP (possibly never used) - TGA optimization by Arseny Kapoulkine - 1.34 (unknown) - use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case - 1.33 (2011-07-14) - make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements - 1.32 (2011-07-13) - support for "info" function for all supported filetypes (SpartanJ) - 1.31 (2011-06-20) - a few more leak fixes, bug in PNG handling (SpartanJ) - 1.30 (2011-06-11) - added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + fix search&replace from 1.36 that messed up comments/error + messages 1.40 (2014-06-22) fix gcc struct-initialization warning 1.39 + (2014-06-15) fix to TGA optimization when req_comp != number of components in + TGA; fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my + test suite) add support for BMP version 5 (more ignored fields) 1.38 + (2014-06-06) suppress MSVC warnings on integer casts truncating values fix + accidental rename of 'skip' field of I/O 1.37 (2014-06-04) remove duplicate + typedef 1.36 (2014-06-03) convert to header file single-file library if + de-iphone isn't set, load iphone images color-swapped instead of returning + NULL 1.35 (2014-05-27) various warnings fix broken STBI_SIMD path fix bug + where stbi_load_from_file no longer left file pointer in correct place fix + broken non-easy path for 32-bit BMP (possibly never used) TGA optimization by + Arseny Kapoulkine 1.34 (unknown) use STBI_NOTUSED in + stbi__resample_row_generic(), fix one more leak in tga failure case 1.33 + (2011-07-14) make stbi_is_hdr work in STBI_NO_HDR (as specified), minor + compiler-friendly improvements 1.32 (2011-07-13) support for "info" function + for all supported filetypes (SpartanJ) 1.31 (2011-06-20) a few more leak + fixes, bug in PNG handling (SpartanJ) 1.30 (2011-06-11) added ability to + load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway - error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) - fix inefficiency in decoding 32-bit BMP (David Woo) - 1.29 (2010-08-16) - various warning fixes from Aurelien Pocheville - 1.28 (2010-08-01) - fix bug in GIF palette transparency (SpartanJ) - 1.27 (2010-08-01) - cast-to-stbi_uc to fix warnings - 1.26 (2010-07-24) - fix bug in file buffering for PNG reported by SpartanJ - 1.25 (2010-07-17) - refix trans_data warning (Won Chun) - 1.24 (2010-07-12) - perf improvements reading from files on platforms with lock-heavy fgetc() - minor perf improvements for jpeg - deprecated type-specific functions so we'll get feedback if they're needed - attempt to fix trans_data warning (Won Chun) - 1.23 fixed bug in iPhone support - 1.22 (2010-07-10) - removed image *writing* support - stbi_info support from Jetro Lauha - GIF support from Jean-Marc Lienher + removed support for installable file formats (stbi_loader) -- + would have been broken for IO callbacks anyway error cases in bmp and tga + give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in + decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from + Aurelien Pocheville 1.28 (2010-08-01) fix bug in GIF palette transparency + (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings 1.26 + (2010-07-24) fix bug in file buffering for PNG reported by SpartanJ 1.25 + (2010-07-17) refix trans_data warning (Won Chun) 1.24 (2010-07-12) perf + improvements reading from files on platforms with lock-heavy fgetc() minor + perf improvements for jpeg deprecated type-specific functions so we'll get + feedback if they're needed attempt to fix trans_data warning (Won Chun) 1.23 + fixed bug in iPhone support 1.22 (2010-07-10) removed image *writing* + support stbi_info support from Jetro Lauha GIF support from Jean-Marc Lienher iPhone PNG-extensions from James Brown - warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) - 1.21 fix use of 'stbi_uc' in header (reported by jon blow) - 1.20 added support for Softimage PIC, by Tom Seddon - 1.19 bug in interlaced PNG corruption check (found by ryg) - 1.18 (2008-08-02) - fix a threading bug (local mutable static) - 1.17 support interlaced PNG - 1.16 major bugfix - stbi__convert_format converted one too many pixels - 1.15 initialize some fields for thread safety - 1.14 fix threadsafe conversion bug - header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. + Janez (U+017D)emva) 1.21 fix use of 'stbi_uc' in header (reported by jon + blow) 1.20 added support for Softimage PIC, by Tom Seddon 1.19 bug in + interlaced PNG corruption check (found by ryg) 1.18 (2008-08-02) fix a + threading bug (local mutable static) 1.17 support interlaced PNG 1.16 + major bugfix - stbi__convert_format converted one too many pixels 1.15 + initialize some fields for thread safety 1.14 fix threadsafe conversion + bug header-file-only version (#define STBI_HEADER_FILE_ONLY before including) 1.13 threadsafe 1.12 const qualifiers in the API 1.11 Support installable IDCT, colorspace conversion routines @@ -7685,15 +8479,14 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz 1.07 attempt to fix C++ warning/errors again 1.06 attempt to fix C++ warning/errors again - 1.05 fix TGA loading to return correct *comp and use good luminance calc - 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free - 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR - 1.02 support for (subset of) HDR files, float interface for preferred access to them - 1.01 fix bug: possible bug in handling right-side up bmps... not sure - fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all - 1.00 interface to zlib that skips zlib header - 0.99 correct handling of alpha in palette - 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 1.05 fix TGA loading to return correct *comp and use good luminance + calc 1.04 default float alpha is 1, not 255; use 'void *' for + stbi_image_free 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR 1.02 support + for (subset of) HDR files, float interface for preferred access to them 1.01 + fix bug: possible bug in handling right-side up bmps... not sure fix bug: the + stbi__bmp_load() and stbi__tga_load() functions didn't work at all 1.00 + interface to zlib that skips zlib header 0.99 correct handling of alpha in + palette 0.98 TGA loader by lonesock; dynamically add loaders (untested) 0.97 jpeg errors on too large a file; also catch another malloc failure 0.96 fix detection of invalid v value - particleman@mollyrocket forum 0.95 during header scan, seek to markers in case of padding @@ -7706,8 +8499,8 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 0.60 fix compiling as c++ 0.59 fix warnings: merge Dave Moore's -Wall fixes 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian - 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available - 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but + less than 16 available 0.56 fix bug: zlib uncompressed mode len vs. nlen 0.55 fix bug: restart_interval not initialized to 0 0.54 allow NULL for 'int *comp' 0.53 fix bug in png 3->4; speedup png decoding @@ -7718,7 +8511,6 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user first released version */ - /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. diff --git a/sample_programs/ml_sample_programs/vision_models/mnist/compile.sh b/sample_programs/ml_sample_programs/vision_models/mnist/compile.sh index c2c9dd86..eff8d7e6 100755 --- a/sample_programs/ml_sample_programs/vision_models/mnist/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/mnist/compile.sh @@ -9,7 +9,7 @@ fi printf "\n[Compile Script]: Convert ONNX model to LLVM IR\n" python3 ../../../../tools/ExtendONNXModel.py --model_path ./model.onnx --output_model_path ./extendedmodel.onnx > expected_op_seq.txt -onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp +onnx-mlir --EmitLLVMIR extendedmodel.onnx --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp mlir-translate -mlir-to-llvmir extendedmodel.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/mnist/image.c b/sample_programs/ml_sample_programs/vision_models/mnist/image.c index 3f915419..2057269e 100644 --- a/sample_programs/ml_sample_programs/vision_models/mnist/image.c +++ b/sample_programs/ml_sample_programs/vision_models/mnist/image.c @@ -169,7 +169,7 @@ void export_layer_output_to_json(OMTensorList *outputList, char* savefile, char* // Get properties of the tensor that you want to export to the JSON file int64_t rank = omTensorGetRank(omt); - int64_t *shape = omTensorGetShape(omt); + const int64_t *shape = omTensorGetShape(omt); int64_t numElements = (int64_t) (omTensorGetNumElems(omt) / shape[0]); float *dataBuf = (float *)omTensorGetDataPtr(omt); int64_t bufferSize = omTensorGetBufferSize(omt); diff --git a/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn-pytorch.py b/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn-pytorch.py index 31b0109f..ba280f12 100644 --- a/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn-pytorch.py +++ b/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn-pytorch.py @@ -46,9 +46,15 @@ def train(args, model, device, train_loader, optimizer, epoch): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) if args.dry_run: break @@ -61,42 +67,95 @@ def test(model, device, test_loader): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + test_loss += F.nll_loss( + output, target, reduction="sum" + ).item() # sum up batch loss + pred = output.argmax( + dim=1, keepdim=True + ) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--dry-run', action='store_true', default=False, - help='quickly check a single pass') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - parser.add_argument('--save-onnx', action='store_true', default=False, - help='For Saving the current Model in onnx') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=14, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=False, + help="For Saving the current Model", + ) + parser.add_argument( + "--save-onnx", + action="store_true", + default=False, + help="For Saving the current Model in onnx", + ) args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() @@ -104,24 +163,19 @@ def main(): device = torch.device("cuda" if use_cuda else "cpu") - train_kwargs = {'batch_size': args.batch_size} - test_kwargs = {'batch_size': args.test_batch_size} + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} if use_cuda: - cuda_kwargs = {'num_workers': 1, - 'pin_memory': True, - 'shuffle': True} + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} train_kwargs.update(cuda_kwargs) test_kwargs.update(cuda_kwargs) - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset1 = datasets.MNIST('../data', train=True, download=True, - transform=transform) - dataset2 = datasets.MNIST('../data', train=False, - transform=transform) - train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) + dataset2 = datasets.MNIST("../data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) model = Net().to(device) @@ -139,17 +193,20 @@ def main(): if args.save_onnx: x = torch.randn(args.test_batch_size, 1, 28, 28, requires_grad=True) - torch.onnx.export(model, # model being run - x.cuda(), # model input (or a tuple for multiple inputs) - "model.onnx", # where to save the model (can be a file or file-like object) - export_params=True, # store the trained parameter weights inside the model file - do_constant_folding=True, # whether to execute constant folding for optimization - input_names = ['input'], # the model's input names - output_names = ['output'], # the model's output names - dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes - 'output' : {0 : 'batch_size'}}) - - -if __name__ == '__main__': + torch.onnx.export( + model, # model being run + x.cuda(), # model input (or a tuple for multiple inputs) + "model.onnx", # where to save the model (can be a file or file-like object) + export_params=True, # store the trained parameter weights inside the model file + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=["input"], # the model's input names + output_names=["output"], # the model's output names + dynamic_axes={ + "input": {0: "batch_size"}, # variable length axes + "output": {0: "batch_size"}, + }, + ) + + +if __name__ == "__main__": main() - diff --git a/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn.py b/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn.py index e42f9187..19d24174 100644 --- a/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn.py +++ b/sample_programs/ml_sample_programs/vision_models/mnist/mnist-cnn.py @@ -1,5 +1,6 @@ from tensorflow.keras import datasets, layers, models, losses, regularizers + def process_images(images): images = images.reshape((-1, 28, 28, 1)) images = images / 255.0 @@ -8,10 +9,7 @@ def process_images(images): def get_model(): model = models.Sequential() - model.add( - layers.Conv2D( - 32, (5, 5), activation="relu", input_shape=( - 28, 28, 1))) + model.add(layers.Conv2D(32, (5, 5), activation="relu", input_shape=(28, 28, 1))) model.add(layers.MaxPooling2D((2, 2))) model.add(layers.Conv2D(64, (5, 5), activation="relu")) model.add(layers.MaxPooling2D((2, 2))) @@ -28,23 +26,24 @@ def get_model(): def main(): - #Load training data - (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data() + # Load training data + (train_images, train_labels), (test_images, test_labels) = ( + datasets.mnist.load_data() + ) train_images = process_images(train_images) test_images = process_images(test_images) - #Setup model + # Setup model model = get_model() modelname = "mnist-cnn" - #Train and test model + # Train and test model model.fit(train_images, train_labels, batch_size=100, epochs=5, verbose=2) - #Save model + # Save model filepath = modelname + ".tf" model.save(filepath) if __name__ == "__main__": main() - diff --git a/sample_programs/ml_sample_programs/vision_models/mnist/mnist-nn.py b/sample_programs/ml_sample_programs/vision_models/mnist/mnist-nn.py index 8d8851ed..fa8af439 100644 --- a/sample_programs/ml_sample_programs/vision_models/mnist/mnist-nn.py +++ b/sample_programs/ml_sample_programs/vision_models/mnist/mnist-nn.py @@ -1,14 +1,15 @@ -#Disable excessive logging +# Disable excessive logging import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -#Setup GPU environment +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +# Setup GPU environment import tensorflow as tf from tensorflow import keras from tensorflow.keras import datasets, layers, models, losses, regularizers import numpy as np -physical_devices = tf.config.list_physical_devices('GPU') +physical_devices = tf.config.list_physical_devices("GPU") for device in physical_devices: tf.config.experimental.set_memory_growth(device, True) @@ -21,13 +22,13 @@ def process_images(images): def get_model(): model = models.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(128, activation="relu"), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation="softmax"), - ] - ) + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation="softmax"), + ] + ) model.compile( optimizer="adam", @@ -38,26 +39,27 @@ def get_model(): def main(): - #Load training data - (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data() + # Load training data + (train_images, train_labels), (test_images, test_labels) = ( + datasets.mnist.load_data() + ) train_images = process_images(train_images) test_images = process_images(test_images) - #Setup model + # Setup model model = get_model() modelname = "mnist-nn" - #Train and test model + # Train and test model model.fit(train_images, train_labels, batch_size=100, epochs=5, verbose=2) test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0) print("Accuracy before faults for ", modelname, "is ", test_acc) - #Save model + # Save model filepath = modelname + ".tf" model.save(filepath) if __name__ == "__main__": main() - diff --git a/sample_programs/ml_sample_programs/vision_models/mnist/stb_image.h b/sample_programs/ml_sample_programs/vision_models/mnist/stb_image.h index accef483..5b891039 100644 --- a/sample_programs/ml_sample_programs/vision_models/mnist/stb_image.h +++ b/sample_programs/ml_sample_programs/vision_models/mnist/stb_image.h @@ -3,7 +3,8 @@ Do this: #define STB_IMAGE_IMPLEMENTATION - before you include this file in *one* C or C++ file to create the implementation. + before you include this file in *one* C or C++ file to create the +implementation. // i.e. it should look like this: #include ... @@ -13,15 +14,16 @@ #include "stb_image.h" You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. - And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using +malloc,realloc,free QUICK NOTES: Primarily of interest to game developers and other people who can avoid problematic images and only need the trivial interface - JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) - PNG 1/2/4/8/16-bit-per-channel + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as +stock IJG lib) PNG 1/2/4/8/16-bit-per-channel TGA (not sure what subset, if a subset) BMP non-1bpp, non-RLE @@ -50,25 +52,22 @@ RECENT REVISION HISTORY: 2.26 (2020-07-13) many minor fixes 2.25 (2020-02-02) fix warnings - 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically - 2.23 (2019-08-11) fix clang static analysis warning - 2.22 (2019-03-04) gif fixes, fix warnings - 2.21 (2019-02-25) fix typo in comment - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and +flip_vertically 2.23 (2019-08-11) fix clang static analysis warning 2.22 +(2019-03-04) gif fixes, fix warnings 2.21 (2019-02-25) fix typo in comment 2.20 +(2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix warnings 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings - 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes - 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 - RGB-format JPEG; remove white matting in PSD; - allocate large structures on the stack; - correct channel count for PNG & BMP - 2.10 (2016-01-22) avoid warning introduced in 2.09 - 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; +bugfixes 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE +detection on GCC 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for +Imagenet JPGs 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; +fixes 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 +(2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 RGB-format JPEG; remove +white matting in PSD; allocate large structures on the stack; correct channel +count for PNG & BMP 2.10 (2016-01-22) avoid warning introduced in 2.09 2.09 +(2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED See end of file for full revision history. @@ -86,38 +85,37 @@ RECENT REVISION HISTORY: github:urraka (animated gif) Junggon Kim (PNM comments) Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) socks-the-fox (16-bit PNG) - Jeremy Sawicki (handle all ImageNet JPGs) - Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Jeremy Sawicki (handle all ImageNet +JPGs) Optimizations & bugfixes Mikhail Morozov (1-bit BMP) Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) Arseny Kapoulkine John-Mark Allen Carmelo J Fdez-Aguera Bug & warning fixes - Marc LeBlanc David Woo Guillaume George Martins Mozeiko - Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski - Phil Jordan Dave Moore Roy Eltham - Hayaki Saito Nathan Reed Won Chun - Luke Graham Johan Duparc Nick Verigakis the Horde3D community - Thomas Ruf Ronny Chevalier github:rlyeh - Janez Zemva John Bartholomew Michal Cichon github:romigrou - Jonathan Blow Ken Hamada Tero Hanninen github:svdijk - Laurent Gomila Cort Stratton github:snagar - Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex - Cass Everitt Ryamond Barbiero github:grim210 - Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw - Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus - Josh Tobin Matthew Gregan github:poppolopoppo - Julian Raschke Gregory Mullen Christian Floisand github:darealshinji - Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 - Brad Weinberger Matvey Cherevko [reserved] - Luca Sas Alexander Veselov Zack Middleton [reserved] + Marc LeBlanc David Woo Guillaume George Martins +Mozeiko Christpher Lloyd Jerry Jansson Joseph Thomson Blazej +Dariusz Roszkowski Phil Jordan Dave Moore Roy +Eltham Hayaki Saito Nathan Reed Won Chun Luke Graham Johan +Duparc Nick Verigakis the Horde3D community Thomas Ruf Ronny +Chevalier github:rlyeh Janez Zemva John +Bartholomew Michal Cichon github:romigrou Jonathan Blow Ken +Hamada Tero Hanninen github:svdijk Laurent Gomila Cort +Stratton github:snagar Aruelien Pocheville Sergio Gonzalez Thibault +Reuille github:Zelex Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Matthew Gregan +github:poppolopoppo Julian Raschke Gregory Mullen Christian +Floisand github:darealshinji Baldur Karlsson Kevin Schmidt JR +Smith github:Michaelangel007 Brad Weinberger Matvey Cherevko +[reserved] Luca Sas Alexander Veselov Zack Middleton [reserved] Ryan C. Gordon [reserved] [reserved] DO NOT ADD YOUR NAME HERE - To add your name to the credits, pick a random blank space in the middle and fill it. - 80% of merge conflicts on stb PRs are due to people adding their name at the end - of the credits. + To add your name to the credits, pick a random blank space in the middle and +fill it. 80% of merge conflicts on stb PRs are due to people adding their name +at the end of the credits. */ #ifndef STBI_INCLUDE_STB_IMAGE_H @@ -136,14 +134,15 @@ RECENT REVISION HISTORY: // // ... process data if not NULL ... // // ... x = width, y = height, n = # 8-bit components per pixel ... // // ... replace '0' with '1'..'4' to force that many components per pixel -// // ... but 'n' will always be the number that it would have been if you said 0 -// stbi_image_free(data) +// // ... but 'n' will always be the number that it would have been if you +// said 0 stbi_image_free(data) // // Standard parameters: // int *x -- outputs image width in pixels // int *y -- outputs image height in pixels // int *channels_in_file -- outputs # of image components in image file -// int desired_channels -- if non-zero, # of image components requested in result +// int desired_channels -- if non-zero, # of image components requested in +// result // // The return value from an image loader is an 'unsigned char *' which points // to the pixel data, or NULL on an allocation failure or if the image is @@ -171,8 +170,8 @@ RECENT REVISION HISTORY: // and *x, *y, *channels_in_file will be unchanged. The function // stbi_failure_reason() can be queried for an extremely brief, end-user // unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS -// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly -// more user-friendly ones. +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get +// slightly more user-friendly ones. // // Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. // @@ -196,11 +195,12 @@ RECENT REVISION HISTORY: // 2. easy to maintain // 3. good performance // -// Sometimes I let "good performance" creep up in priority over "easy to maintain", -// and for best performance I may provide less-easy-to-use APIs that give higher -// performance, in addition to the easy-to-use ones. Nevertheless, it's important -// to keep in mind that from the standpoint of you, a client of this library, -// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// Sometimes I let "good performance" creep up in priority over "easy to +// maintain", and for best performance I may provide less-easy-to-use APIs that +// give higher performance, in addition to the easy-to-use ones. Nevertheless, +// it's important to keep in mind that from the standpoint of you, a client of +// this library, all you care about is #1 and #3, and stb libraries DO NOT +// emphasize #3 above all. // // Some secondary priorities arise directly from the first two, some of which // provide more explicit reasons why performance can't be emphasized. @@ -219,7 +219,8 @@ RECENT REVISION HISTORY: // overhead. // // The three functions you must define are "read" (reads some bytes of data), -// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the +// end). // // =========================================================================== // @@ -247,10 +248,11 @@ RECENT REVISION HISTORY: // HDR image support (disable by defining STBI_NO_HDR) // // stb_image supports loading HDR images in general, and currently the Radiance -// .HDR file format specifically. You can still load any file through the existing -// interface; if you attempt to load an HDR file, it will be automatically remapped -// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; -// both of these constants can be reconfigured through this interface: +// .HDR file format specifically. You can still load any file through the +// existing interface; if you attempt to load an HDR file, it will be +// automatically remapped to LDR, assuming gamma 2.2 and an arbitrary scale +// factor defaulting to 1; both of these constants can be reconfigured through +// this interface: // // stbi_hdr_to_ldr_gamma(2.2f); // stbi_hdr_to_ldr_scale(1.0f); @@ -342,14 +344,13 @@ RECENT REVISION HISTORY: #define STBI_VERSION 1 -enum -{ - STBI_default = 0, // only used for desired_channels +enum { + STBI_default = 0, // only used for desired_channels - STBI_grey = 1, - STBI_grey_alpha = 2, - STBI_rgb = 3, - STBI_rgb_alpha = 4 + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 }; #include @@ -377,11 +378,13 @@ extern "C" { // load image by filename, open file, or memory buffer // -typedef struct -{ - int (*read) (void *user,char *data,int size); // fill 'data' with 'size' bytes. return number of bytes actually read - void (*skip) (void *user,int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative - int (*eof) (void *user); // returns nonzero if we are at end of file/data +typedef struct { + int (*read)(void *user, char *data, + int size); // fill 'data' with 'size' bytes. return number of + // bytes actually read + void (*skip)(void *user, int n); // skip the next 'n' bytes, or 'unget' the + // last -n bytes if negative + int (*eof)(void *user); // returns nonzero if we are at end of file/data } stbi_io_callbacks; //////////////////////////////////// @@ -389,21 +392,33 @@ typedef struct // 8-bits-per-channel interface // -STBIDEF stbi_uc *stbi_load_from_memory (stbi_uc const *buffer, int len , int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk , void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_uc *stbi_load (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_uc *stbi_load_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); -// for stbi_load_from_file, file pointer is left pointing immediately after image +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after +// image #endif #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp); #endif #ifdef STBI_WINDOWS_UTF8 -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input); #endif //////////////////////////////////// @@ -411,12 +426,20 @@ STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wch // 16-bits-per-channel interface // -STBIDEF stbi_us *stbi_load_16_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels); +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); #ifndef STBI_NO_STDIO -STBIDEF stbi_us *stbi_load_16 (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); -STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); #endif //////////////////////////////////// @@ -424,83 +447,102 @@ STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_i // float-per-channel interface // #ifndef STBI_NO_LINEAR - STBIDEF float *stbi_loadf_from_memory (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_callbacks (stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *channels_in_file, + int desired_channels); +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels); - #ifndef STBI_NO_STDIO - STBIDEF float *stbi_loadf (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels); - STBIDEF float *stbi_loadf_from_file (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels); - #endif +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, + int *channels_in_file, int desired_channels); +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, + int *channels_in_file, + int desired_channels); +#endif #endif #ifndef STBI_NO_HDR - STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); - STBIDEF void stbi_hdr_to_ldr_scale(float scale); +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); +STBIDEF void stbi_hdr_to_ldr_scale(float scale); #endif // STBI_NO_HDR #ifndef STBI_NO_LINEAR - STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); - STBIDEF void stbi_ldr_to_hdr_scale(float scale); +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); +STBIDEF void stbi_ldr_to_hdr_scale(float scale); #endif // STBI_NO_LINEAR // stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename); -STBIDEF int stbi_is_hdr_from_file(FILE *f); +STBIDEF int stbi_is_hdr(char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); #endif // STBI_NO_STDIO - // get a VERY brief reason for failure // on most compilers (and ALL modern mainstream compilers) this is threadsafe -STBIDEF const char *stbi_failure_reason (void); +STBIDEF const char *stbi_failure_reason(void); // free the loaded image -- this is just free() -STBIDEF void stbi_image_free (void *retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load); // get image dimensions & components without fully decoding -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, + int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, + void *user); #ifndef STBI_NO_STDIO -STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); -STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); -STBIDEF int stbi_is_16_bit (char const *filename); -STBIDEF int stbi_is_16_bit_from_file(FILE *f); +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit(char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); #endif - - // for image formats that explicitly notate that they have premultiplied alpha, // we just return the colors as stored in the file. set this flag to force // unpremultiplication. results are undefined if the unpremultiply overflow. -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); // indicate whether we should process iphone images back to canonical format, // or just pass them through "as-is" STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); -// flip the image vertically, so the first pixel in the output array is the bottom left +// flip the image vertically, so the first pixel in the output array is the +// bottom left STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); -// as above, but only applies to images loaded on the thread that calls the function -// this function is only available if your compiler supports thread-local variables; -// calling it will fail to link if your compiler doesn't -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); +// as above, but only applies to images loaded on the thread that calls the +// function this function is only available if your compiler supports +// thread-local variables; calling it will fail to link if your compiler doesn't +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); // ZLIB client - used by PNG, available for other purposes -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header); STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); - -STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, + int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen); #ifdef __cplusplus } @@ -513,52 +555,53 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #ifdef STB_IMAGE_IMPLEMENTATION -#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) \ - || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \ - || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \ - || defined(STBI_ONLY_ZLIB) - #ifndef STBI_ONLY_JPEG - #define STBI_NO_JPEG - #endif - #ifndef STBI_ONLY_PNG - #define STBI_NO_PNG - #endif - #ifndef STBI_ONLY_BMP - #define STBI_NO_BMP - #endif - #ifndef STBI_ONLY_PSD - #define STBI_NO_PSD - #endif - #ifndef STBI_ONLY_TGA - #define STBI_NO_TGA - #endif - #ifndef STBI_ONLY_GIF - #define STBI_NO_GIF - #endif - #ifndef STBI_ONLY_HDR - #define STBI_NO_HDR - #endif - #ifndef STBI_ONLY_PIC - #define STBI_NO_PIC - #endif - #ifndef STBI_ONLY_PNM - #define STBI_NO_PNM - #endif -#endif - -#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) -#define STBI_NO_ZLIB +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || \ + defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || \ + defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ + defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) +#ifndef STBI_ONLY_JPEG +#define STBI_NO_JPEG +#endif +#ifndef STBI_ONLY_PNG +#define STBI_NO_PNG +#endif +#ifndef STBI_ONLY_BMP +#define STBI_NO_BMP +#endif +#ifndef STBI_ONLY_PSD +#define STBI_NO_PSD +#endif +#ifndef STBI_ONLY_TGA +#define STBI_NO_TGA +#endif +#ifndef STBI_ONLY_GIF +#define STBI_NO_GIF +#endif +#ifndef STBI_ONLY_HDR +#define STBI_NO_HDR +#endif +#ifndef STBI_ONLY_PIC +#define STBI_NO_PIC +#endif +#ifndef STBI_ONLY_PNM +#define STBI_NO_PNM +#endif #endif +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && \ + !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif +#include #include #include // ptrdiff_t on osx #include #include -#include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow +#include // ldexp, pow #endif #ifndef STBI_NO_STDIO @@ -576,55 +619,55 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #define STBI_EXTERN extern #endif - #ifndef _MSC_VER - #ifdef __cplusplus - #define stbi_inline inline - #else - #define stbi_inline - #endif +#ifdef __cplusplus +#define stbi_inline inline +#else +#define stbi_inline +#endif #else - #define stbi_inline __forceinline +#define stbi_inline __forceinline #endif #ifndef STBI_NO_THREAD_LOCALS - #if defined(__cplusplus) && __cplusplus >= 201103L - #define STBI_THREAD_LOCAL thread_local - #elif defined(__GNUC__) && __GNUC__ < 5 - #define STBI_THREAD_LOCAL __thread - #elif defined(_MSC_VER) - #define STBI_THREAD_LOCAL __declspec(thread) - #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) - #define STBI_THREAD_LOCAL _Thread_local - #endif - - #ifndef STBI_THREAD_LOCAL - #if defined(__GNUC__) - #define STBI_THREAD_LOCAL __thread - #endif - #endif +#if defined(__cplusplus) && __cplusplus >= 201103L +#define STBI_THREAD_LOCAL thread_local +#elif defined(__GNUC__) && __GNUC__ < 5 +#define STBI_THREAD_LOCAL __thread +#elif defined(_MSC_VER) +#define STBI_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && \ + !defined(__STDC_NO_THREADS__) +#define STBI_THREAD_LOCAL _Thread_local +#endif + +#ifndef STBI_THREAD_LOCAL +#if defined(__GNUC__) +#define STBI_THREAD_LOCAL __thread +#endif +#endif #endif #ifdef _MSC_VER typedef unsigned short stbi__uint16; -typedef signed short stbi__int16; -typedef unsigned int stbi__uint32; -typedef signed int stbi__int32; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; #else #include typedef uint16_t stbi__uint16; -typedef int16_t stbi__int16; +typedef int16_t stbi__int16; typedef uint32_t stbi__uint32; -typedef int32_t stbi__int32; +typedef int32_t stbi__int32; #endif // should produce compiler error if size is wrong -typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; +typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; #ifdef _MSC_VER -#define STBI_NOTUSED(v) (void)(v) +#define STBI_NOTUSED(v) (void)(v) #else -#define STBI_NOTUSED(v) (void)sizeof(v) +#define STBI_NOTUSED(v) (void)sizeof(v) #endif #ifdef _MSC_VER @@ -632,27 +675,30 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #endif #ifdef STBI_HAS_LROTL - #define stbi_lrot(x,y) _lrotl(x,y) +#define stbi_lrot(x, y) _lrotl(x, y) #else - #define stbi_lrot(x,y) (((x) << (y)) | ((x) >> (32 - (y)))) +#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (32 - (y)))) #endif -#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +#if defined(STBI_MALLOC) && defined(STBI_FREE) && \ + (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) // ok -#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && \ + !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) // ok #else -#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." +#error \ + "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." #endif #ifndef STBI_MALLOC -#define STBI_MALLOC(sz) malloc(sz) -#define STBI_REALLOC(p,newsz) realloc(p,newsz) -#define STBI_FREE(p) free(p) +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p, newsz) realloc(p, newsz) +#define STBI_FREE(p) free(p) #endif #ifndef STBI_REALLOC_SIZED -#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz) +#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) #endif // x86/x64 detection @@ -662,7 +708,8 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI__X86_TARGET #endif -#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && \ + !defined(STBI_NO_SIMD) // gcc doesn't support sse2 intrinsics unless you compile with -msse2, // which in turn means it gets to use SSE2 everywhere. This is unfortunate, // but previous attempts to provide the SSE2 functions with runtime @@ -673,8 +720,10 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; #define STBI_NO_SIMD #endif -#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) -// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && \ + !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid +// STBI__X64_TARGET // // 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the // Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. @@ -684,44 +733,43 @@ typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1]; // See https://github.com/nothings/stb/issues/81 for more information. // // So default to no SSE2 on 32-bit MinGW. If you've read this far and added -// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +// -mstackrealign to your build settings, feel free to #define +// STBI_MINGW_ENABLE_SSE2. #define STBI_NO_SIMD #endif -#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#if !defined(STBI_NO_SIMD) && \ + (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) #define STBI_SSE2 #include #ifdef _MSC_VER -#if _MSC_VER >= 1400 // not VC6 -#include // __cpuid -static int stbi__cpuid3(void) -{ - int info[4]; - __cpuid(info,1); - return info[3]; +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) { + int info[4]; + __cpuid(info, 1); + return info[3]; } #else -static int stbi__cpuid3(void) -{ - int res; - __asm { +static int stbi__cpuid3(void) { + int res; + __asm { mov eax,1 cpuid mov res,edx - } - return res; + } + return res; } #endif #define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - int info3 = stbi__cpuid3(); - return ((info3 >> 26) & 1) != 0; +static int stbi__sse2_available(void) { + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; } #endif @@ -729,12 +777,11 @@ static int stbi__sse2_available(void) #define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) #if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) -static int stbi__sse2_available(void) -{ - // If we're even attempting to compile this on GCC/Clang, that means - // -msse2 is on, which means the compiler is allowed to use SSE2 - // instructions at will, and so are we. - return 1; +static int stbi__sse2_available(void) { + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. + return 1; } #endif @@ -766,188 +813,182 @@ static int stbi__sse2_available(void) // stbi__context structure is our basic context used by all images, so it // contains all the IO context, plus some basic image information -typedef struct -{ - stbi__uint32 img_x, img_y; - int img_n, img_out_n; +typedef struct { + stbi__uint32 img_x, img_y; + int img_n, img_out_n; - stbi_io_callbacks io; - void *io_user_data; + stbi_io_callbacks io; + void *io_user_data; - int read_from_callbacks; - int buflen; - stbi_uc buffer_start[128]; - int callback_already_read; + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; - stbi_uc *img_buffer, *img_buffer_end; - stbi_uc *img_buffer_original, *img_buffer_original_end; + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; } stbi__context; - static void stbi__refill_buffer(stbi__context *s); // initialize a memory-decode context -static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) -{ - s->io.read = NULL; - s->read_from_callbacks = 0; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; - s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) { + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; } // initialize a callback-based context -static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) -{ - s->io = *c; - s->io_user_data = user; - s->buflen = sizeof(s->buffer_start); - s->read_from_callbacks = 1; - s->callback_already_read = 0; - s->img_buffer = s->img_buffer_original = s->buffer_start; - stbi__refill_buffer(s); - s->img_buffer_original_end = s->img_buffer_end; +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, + void *user) { + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; } #ifndef STBI_NO_STDIO -static int stbi__stdio_read(void *user, char *data, int size) -{ - return (int) fread(data,1,size,(FILE*) user); +static int stbi__stdio_read(void *user, char *data, int size) { + return (int)fread(data, 1, size, (FILE *)user); } -static void stbi__stdio_skip(void *user, int n) -{ - int ch; - fseek((FILE*) user, n, SEEK_CUR); - ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ - if (ch != EOF) { - ungetc(ch, (FILE *) user); /* push byte back onto stream if valid. */ - } +static void stbi__stdio_skip(void *user, int n) { + int ch; + fseek((FILE *)user, n, SEEK_CUR); + ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. */ + } } -static int stbi__stdio_eof(void *user) -{ - return feof((FILE*) user) || ferror((FILE *) user); +static int stbi__stdio_eof(void *user) { + return feof((FILE *)user) || ferror((FILE *)user); } -static stbi_io_callbacks stbi__stdio_callbacks = -{ - stbi__stdio_read, - stbi__stdio_skip, - stbi__stdio_eof, +static stbi_io_callbacks stbi__stdio_callbacks = { + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, }; -static void stbi__start_file(stbi__context *s, FILE *f) -{ - stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +static void stbi__start_file(stbi__context *s, FILE *f) { + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } -//static void stop_file(stbi__context *s) { } +// static void stop_file(stbi__context *s) { } #endif // !STBI_NO_STDIO -static void stbi__rewind(stbi__context *s) -{ - // conceptually rewind SHOULD rewind to the beginning of the stream, - // but we just rewind to the beginning of the initial buffer, because - // we only use it after doing 'test', which only ever looks at at most 92 bytes - s->img_buffer = s->img_buffer_original; - s->img_buffer_end = s->img_buffer_original_end; +static void stbi__rewind(stbi__context *s) { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 + // bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; } -enum -{ - STBI_ORDER_RGB, - STBI_ORDER_BGR -}; +enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; -typedef struct -{ - int bits_per_channel; - int num_channels; - int channel_order; +typedef struct { + int bits_per_channel; + int num_channels; + int channel_order; } stbi__result_info; #ifndef STBI_NO_JPEG -static int stbi__jpeg_test(stbi__context *s); -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNG -static int stbi__png_test(stbi__context *s); -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__png_is16(stbi__context *s); +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); #endif #ifndef STBI_NO_BMP -static int stbi__bmp_test(stbi__context *s); -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_TGA -static int stbi__tga_test(stbi__context *s); -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s); -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); -static int stbi__psd_is16(stbi__context *s); +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); #endif #ifndef STBI_NO_HDR -static int stbi__hdr_test(stbi__context *s); -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PIC -static int stbi__pic_test(stbi__context *s); -static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_GIF -static int stbi__gif_test(stbi__context *s); -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); #endif #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s); -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); #endif static #ifdef STBI_THREAD_LOCAL -STBI_THREAD_LOCAL + STBI_THREAD_LOCAL #endif -const char *stbi__g_failure_reason; + const char *stbi__g_failure_reason; -STBIDEF const char *stbi_failure_reason(void) -{ - return stbi__g_failure_reason; +STBIDEF const char *stbi_failure_reason(void) { + return stbi__g_failure_reason; } #ifndef STBI_NO_FAILURE_STRINGS -static int stbi__err(const char *str) -{ - stbi__g_failure_reason = str; - return 0; +static int stbi__err(const char *str) { + stbi__g_failure_reason = str; + return 0; } #endif -static void *stbi__malloc(size_t size) -{ - return STBI_MALLOC(size); +static void *stbi__malloc(size_t size) { + return STBI_MALLOC(size); } // stb_image uses ints pervasively, including for offset calculations. @@ -962,70 +1003,72 @@ static void *stbi__malloc(size_t size) // return 1 if the sum is valid, 0 on overflow. // negative terms are considered invalid. -static int stbi__addsizes_valid(int a, int b) -{ - if (b < 0) return 0; - // now 0 <= b <= INT_MAX, hence also - // 0 <= INT_MAX - b <= INTMAX. - // And "a + b <= INT_MAX" (which might overflow) is the - // same as a <= INT_MAX - b (no overflow) - return a <= INT_MAX - b; +static int stbi__addsizes_valid(int a, int b) { + if (b < 0) + return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; } // returns 1 if the product is valid, 0 on overflow. // negative factors are considered invalid. -static int stbi__mul2sizes_valid(int a, int b) -{ - if (a < 0 || b < 0) return 0; - if (b == 0) return 1; // mul-by-0 is always safe - // portable way to check for no overflows in a*b - return a <= INT_MAX/b; +static int stbi__mul2sizes_valid(int a, int b) { + if (a < 0 || b < 0) + return 0; + if (b == 0) + return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; } -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow -static int stbi__mad2sizes_valid(int a, int b, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +static int stbi__mad2sizes_valid(int a, int b, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); } #endif // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow -static int stbi__mad3sizes_valid(int a, int b, int c, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__addsizes_valid(a*b*c, add); +static int stbi__mad3sizes_valid(int a, int b, int c, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__addsizes_valid(a * b * c, add); } -// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't +// overflow #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) -{ - return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && - stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && + stbi__mul2sizes_valid(a * b * c, d) && + stbi__addsizes_valid(a * b * c * d, add); } #endif -#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || \ + !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) // mallocs with size overflow checking -static void *stbi__malloc_mad2(int a, int b, int add) -{ - if (!stbi__mad2sizes_valid(a, b, add)) return NULL; - return stbi__malloc(a*b + add); +static void *stbi__malloc_mad2(int a, int b, int add) { + if (!stbi__mad2sizes_valid(a, b, add)) + return NULL; + return stbi__malloc(a * b + add); } #endif -static void *stbi__malloc_mad3(int a, int b, int c, int add) -{ - if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; - return stbi__malloc(a*b*c + add); +static void *stbi__malloc_mad3(int a, int b, int c, int add) { + if (!stbi__mad3sizes_valid(a, b, c, add)) + return NULL; + return stbi__malloc(a * b * c + add); } #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) -{ - if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; - return stbi__malloc(a*b*c*d + add); +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) { + if (!stbi__mad4sizes_valid(a, b, c, d, add)) + return NULL; + return stbi__malloc(a * b * c * d + add); } #endif @@ -1034,417 +1077,459 @@ static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) // stbi__errpuc - error returning pointer to unsigned char #ifdef STBI_NO_FAILURE_STRINGS - #define stbi__err(x,y) 0 +#define stbi__err(x, y) 0 #elif defined(STBI_FAILURE_USERMSG) - #define stbi__err(x,y) stbi__err(y) +#define stbi__err(x, y) stbi__err(y) #else - #define stbi__err(x,y) stbi__err(x) +#define stbi__err(x, y) stbi__err(x) #endif -#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) -#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpuc(x, y) \ + ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) -STBIDEF void stbi_image_free(void *retval_from_stbi_load) -{ - STBI_FREE(retval_from_stbi_load); +STBIDEF void stbi_image_free(void *retval_from_stbi_load) { + STBI_FREE(retval_from_stbi_load); } #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); #endif #ifndef STBI_NO_HDR -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); #endif static int stbi__vertically_flip_on_load_global = 0; -STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; } #ifndef STBI_THREAD_LOCAL -#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global #else -static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, + stbi__vertically_flip_on_load_set; -STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) -{ - stbi__vertically_flip_on_load_local = flag_true_if_should_flip; - stbi__vertically_flip_on_load_set = 1; +STBIDEF void +stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; } -#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ - ? stbi__vertically_flip_on_load_local \ - : stbi__vertically_flip_on_load_global) +#define stbi__vertically_flip_on_load \ + (stbi__vertically_flip_on_load_set ? stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) #endif // STBI_THREAD_LOCAL -static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields - ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed - ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order - ri->num_channels = 0; - - #ifndef STBI_NO_JPEG - if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNG - if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_BMP - if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_GIF - if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PSD - if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); - #else - STBI_NOTUSED(bpc); - #endif - #ifndef STBI_NO_PIC - if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); - #endif - #ifndef STBI_NO_PNM - if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); - #endif - - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); - return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); - } - #endif - - #ifndef STBI_NO_TGA - // test tga last because it's a crappy test! - if (stbi__tga_test(s)) - return stbi__tga_load(s,x,y,comp,req_comp, ri); - #endif - - return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); -} - -static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi_uc *reduced; - - reduced = (stbi_uc *) stbi__malloc(img_len); - if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); - - for (i = 0; i < img_len; ++i) - reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling - - STBI_FREE(orig); - return reduced; -} - -static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) -{ - int i; - int img_len = w * h * channels; - stbi__uint16 *enlarged; +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = + 8; // default is 8 so most paths don't have to be changed + ri->channel_order = + STBI_ORDER_RGB; // all current input & output are this, but this is here + // so we can add BGR order + ri->num_channels = 0; - enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); - if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); +#ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) + return stbi__jpeg_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNG + if (stbi__png_test(s)) + return stbi__png_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) + return stbi__bmp_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_GIF + if (stbi__gif_test(s)) + return stbi__gif_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PSD + if (stbi__psd_test(s)) + return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); +#else + STBI_NOTUSED(bpc); +#endif +#ifndef STBI_NO_PIC + if (stbi__pic_test(s)) + return stbi__pic_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) + return stbi__pnm_load(s, x, y, comp, req_comp, ri); +#endif - for (i = 0; i < img_len; ++i) - enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } +#endif - STBI_FREE(orig); - return enlarged; -} +#ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s, x, y, comp, req_comp, ri); +#endif -static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) -{ - int row; - size_t bytes_per_row = (size_t)w * bytes_per_pixel; - stbi_uc temp[2048]; - stbi_uc *bytes = (stbi_uc *)image; - - for (row = 0; row < (h>>1); row++) { - stbi_uc *row0 = bytes + row*bytes_per_row; - stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; - // swap row0 with row1 - size_t bytes_left = bytes_per_row; - while (bytes_left) { - size_t bytes_copy = (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); - memcpy(temp, row0, bytes_copy); - memcpy(row0, row1, bytes_copy); - memcpy(row1, temp, bytes_copy); - row0 += bytes_copy; - row1 += bytes_copy; - bytes_left -= bytes_copy; - } - } + return stbi__errpuc("unknown image type", + "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *)stbi__malloc(img_len); + if (reduced == NULL) + return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = + (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient + // approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, + int channels) { + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); + if (enlarged == NULL) + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + + orig[i]); // replicate to high and low byte, + // maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, + int bytes_per_pixel) { + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h >> 1); row++) { + stbi_uc *row0 = bytes + row * bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1) * bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = + (bytes_left < sizeof(temp)) ? bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } } #ifndef STBI_NO_GIF -static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) -{ - int slice; - int slice_size = w * h * bytes_per_pixel; +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, + int bytes_per_pixel) { + int slice; + int slice_size = w * h * bytes_per_pixel; - stbi_uc *bytes = (stbi_uc *)image; - for (slice = 0; slice < z; ++slice) { - stbi__vertical_flip(bytes, w, h, bytes_per_pixel); - bytes += slice_size; - } + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } } #endif -static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 8) { - result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 8; - } + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } - // @TODO: move stbi__convert_format to here + // @TODO: move stbi__convert_format to here - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } - return (unsigned char *) result; + return (unsigned char *)result; } -static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - stbi__result_info ri; - void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, + int *y, int *comp, + int req_comp) { + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); - if (result == NULL) - return NULL; + if (result == NULL) + return NULL; - // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. - STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + // it is the responsibility of the loaders to make sure we get either 8 or 16 + // bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); - if (ri.bits_per_channel != 16) { - result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); - ri.bits_per_channel = 16; - } + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, + req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } - // @TODO: move stbi__convert_format16 to here - // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to + // keep more precision - if (stbi__vertically_flip_on_load) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); - } + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } - return (stbi__uint16 *) result; + return (stbi__uint16 *)result; } #if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) -static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) -{ - if (stbi__vertically_flip_on_load && result != NULL) { - int channels = req_comp ? req_comp : *comp; - stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); - } +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, + int req_comp) { + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } } #endif #ifndef STBI_NO_STDIO #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); -STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar( + unsigned int cp, unsigned long flags, const char *str, int cbmb, + wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte( + unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, + char *str, int cbmb, const char *defchar, int *used_default); #endif #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) -STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) -{ - return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, + const wchar_t *input) { + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, + (int)bufferlen, NULL, NULL); } #endif -static FILE *stbi__fopen(char const *filename, char const *mode) -{ - FILE *f; +static FILE *stbi__fopen(char const *filename, char const *mode) { + FILE *f; #if defined(_MSC_VER) && defined(STBI_WINDOWS_UTF8) - wchar_t wMode[64]; - wchar_t wFilename[1024]; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename))) - return 0; + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, + sizeof(wFilename))) + return 0; - if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) - return 0; + if (0 == + MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode))) + return 0; #if _MSC_VER >= 1400 - if (0 != _wfopen_s(&f, wFilename, wMode)) - f = 0; + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; #else - f = _wfopen(wFilename, wMode); + f = _wfopen(wFilename, wMode); #endif #elif defined(_MSC_VER) && _MSC_VER >= 1400 - if (0 != fopen_s(&f, filename, mode)) - f=0; + if (0 != fopen_s(&f, filename, mode)) + f = 0; #else - f = fopen(filename, mode); + f = fopen(filename, mode); #endif - return f; -} - - -STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - unsigned char *result; - if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; -} - -STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__uint16 *result; - stbi__context s; - stbi__start_file(&s,f); - result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); - if (result) { - // need to 'unget' all the characters in the IO buffer - fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); - } - return result; -} - -STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - stbi__uint16 *result; - if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); - result = stbi_load_from_file_16(f,x,y,comp,req_comp); - fclose(f); - return result; -} - - -#endif //!STBI_NO_STDIO - -STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); - return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); -} - -STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); -} - -STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + return f; +} + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) + return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, + int req_comp) { + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) + return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +#endif //! STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, + int *x, int *y, int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, + int *channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, + desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); } #ifndef STBI_NO_GIF -STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - unsigned char *result; - stbi__context s; - stbi__start_mem(&s,buffer,len); +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, + int **delays, int *x, int *y, int *z, + int *comp, int req_comp) { + unsigned char *result; + stbi__context s; + stbi__start_mem(&s, buffer, len); - result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); - if (stbi__vertically_flip_on_load) { - stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); - } + result = + (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices(result, *x, *y, *z, *comp); + } - return result; + return result; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) -{ - unsigned char *data; - #ifndef STBI_NO_HDR - if (stbi__hdr_test(s)) { - stbi__result_info ri; - float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); - if (hdr_data) - stbi__float_postprocess(hdr_data,x,y,comp,req_comp); - return hdr_data; - } - #endif - data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); - if (data) - return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); - return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); -} - -STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__loadf_main(&s,x,y,comp,req_comp); +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, + int req_comp) { + unsigned char *data; +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data, x, y, comp, req_comp); + return hdr_data; + } +#endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", + "Image not of any known type, or corrupt"); } -STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__loadf_main(&s,x,y,comp,req_comp); +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -#ifndef STBI_NO_STDIO -STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) -{ - float *result; - FILE *f = stbi__fopen(filename, "rb"); - if (!f) return stbi__errpf("can't fopen", "Unable to open file"); - result = stbi_loadf_from_file(f,x,y,comp,req_comp); - fclose(f); - return result; +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, + void *user, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__loadf_main(&s, x, y, comp, req_comp); } -STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) -{ - stbi__context s; - stbi__start_file(&s,f); - return stbi__loadf_main(&s,x,y,comp,req_comp); +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, + int req_comp) { + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) + return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, + int req_comp) { + stbi__context s; + stbi__start_file(&s, f); + return stbi__loadf_main(&s, x, y, comp, req_comp); } #endif // !STBI_NO_STDIO @@ -1454,221 +1539,222 @@ STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_ // defined, for API simplicity; if STBI_NO_LINEAR is defined, it always // reports false! -STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(buffer); - STBI_NOTUSED(len); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; +#endif } #ifndef STBI_NO_STDIO -STBIDEF int stbi_is_hdr (char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result=0; - if (f) { - result = stbi_is_hdr_from_file(f); - fclose(f); - } - return result; +STBIDEF int stbi_is_hdr(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result = 0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; } -STBIDEF int stbi_is_hdr_from_file(FILE *f) -{ - #ifndef STBI_NO_HDR - long pos = ftell(f); - int res; - stbi__context s; - stbi__start_file(&s,f); - res = stbi__hdr_test(&s); - fseek(f, pos, SEEK_SET); - return res; - #else - STBI_NOTUSED(f); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_file(FILE *f) { +#ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s, f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; +#else + STBI_NOTUSED(f); + return 0; +#endif } #endif // !STBI_NO_STDIO -STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) -{ - #ifndef STBI_NO_HDR - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); - return stbi__hdr_test(&s); - #else - STBI_NOTUSED(clbk); - STBI_NOTUSED(user); - return 0; - #endif +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, + void *user) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; +#endif } #ifndef STBI_NO_LINEAR -static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; +static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; -STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } -STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { + stbi__l2h_gamma = gamma; +} +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { + stbi__l2h_scale = scale; +} #endif -static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; - -STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } -STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } +static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { + stbi__h2l_gamma_i = 1 / gamma; +} +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { + stbi__h2l_scale_i = 1 / scale; +} ////////////////////////////////////////////////////////////////////////////// // // Common code used by all image loaders // -enum -{ - STBI__SCAN_load=0, - STBI__SCAN_type, - STBI__SCAN_header -}; - -static void stbi__refill_buffer(stbi__context *s) -{ - int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); - s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); - if (n == 0) { - // at end of file, treat same as if from memory, but need to handle case - // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file - s->read_from_callbacks = 0; - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start+1; - *s->img_buffer = 0; - } else { - s->img_buffer = s->buffer_start; - s->img_buffer_end = s->buffer_start + n; - } -} - -stbi_inline static stbi_uc stbi__get8(stbi__context *s) -{ - if (s->img_buffer < s->img_buffer_end) - return *s->img_buffer++; - if (s->read_from_callbacks) { - stbi__refill_buffer(s); - return *s->img_buffer++; - } - return 0; -} - -#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; + +static void stbi__refill_buffer(stbi__context *s) { + int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -stbi_inline static int stbi__at_eof(stbi__context *s) -{ - if (s->io.read) { - if (!(s->io.eof)(s->io_user_data)) return 0; - // if feof() is true, check if buffer = end - // special case: we've only got the special 0 character at the end - if (s->read_from_callbacks == 0) return 1; - } +stbi_inline static int stbi__at_eof(stbi__context *s) { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) + return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) + return 1; + } - return s->img_buffer >= s->img_buffer_end; + return s->img_buffer >= s->img_buffer_end; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) // nothing #else -static void stbi__skip(stbi__context *s, int n) -{ - if (n == 0) return; // already there! - if (n < 0) { +static void stbi__skip(stbi__context *s, int n) { + if (n == 0) + return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); return; - } - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - s->img_buffer = s->img_buffer_end; - (s->io.skip)(s->io_user_data, n - blen); - return; - } - } - s->img_buffer += n; + } + } + s->img_buffer += n; } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && \ + defined(STBI_NO_PNM) // nothing #else -static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) -{ - if (s->io.read) { - int blen = (int) (s->img_buffer_end - s->img_buffer); - if (blen < n) { - int res, count; +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) { + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; - memcpy(buffer, s->img_buffer, blen); + memcpy(buffer, s->img_buffer, blen); - count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); - res = (count == (n-blen)); - s->img_buffer = s->img_buffer_end; - return res; - } - } + count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); + res = (count == (n - blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } - if (s->img_buffer+n <= s->img_buffer_end) { - memcpy(buffer, s->img_buffer, n); - s->img_buffer += n; - return 1; - } else - return 0; + if (s->img_buffer + n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; } #endif -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_PIC) // nothing #else -static int stbi__get16be(stbi__context *s) -{ - int z = stbi__get8(s); - return (z << 8) + stbi__get8(s); +static int stbi__get16be(stbi__context *s) { + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) // nothing #else -static stbi__uint32 stbi__get32be(stbi__context *s) -{ - stbi__uint32 z = stbi__get16be(s); - return (z << 16) + stbi__get16be(s); +static stbi__uint32 stbi__get32be(stbi__context *s) { + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); } #endif #if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) // nothing #else -static int stbi__get16le(stbi__context *s) -{ - int z = stbi__get8(s); - return z + (stbi__get8(s) << 8); +static int stbi__get16le(stbi__context *s) { + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); } #endif #ifndef STBI_NO_BMP -static stbi__uint32 stbi__get32le(stbi__context *s) -{ - stbi__uint32 z = stbi__get16le(s); - return z + (stbi__get16le(s) << 16); +static stbi__uint32 stbi__get32le(stbi__context *s) { + stbi__uint32 z = stbi__get16le(s); + return z + (stbi__get16le(s) << 16); } #endif -#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings +#define STBI__BYTECAST(x) \ + ((stbi_uc)((x) & 255)) // truncate int to byte without warnings -#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && \ + defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) && defined(STBI_NO_PNM) // nothing #else ////////////////////////////////////////////////////////////////////////////// @@ -1682,169 +1768,301 @@ static stbi__uint32 stbi__get32le(stbi__context *s) // assume data buffer is malloced, so malloc a new one and free that one // only failure mode is malloc failing -static stbi_uc stbi__compute_y(int r, int g, int b) -{ - return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi_uc stbi__compute_y(int r, int g, int b) { + return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif -#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && \ + defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && \ + defined(STBI_NO_PNM) // nothing #else -static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - unsigned char *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); - if (good == NULL) { - STBI_FREE(data); - return stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - unsigned char *src = data + j * x * img_n ; - unsigned char *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + unsigned char *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + unsigned char *src = data + j * x * img_n; + unsigned char *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 255; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 255; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 255; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = 255; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 stbi__compute_y_16(int r, int g, int b) -{ - return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { + return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } #endif #if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) // nothing #else -static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) -{ - int i,j; - stbi__uint16 *good; - - if (req_comp == img_n) return data; - STBI_ASSERT(req_comp >= 1 && req_comp <= 4); - - good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); - if (good == NULL) { - STBI_FREE(data); - return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); - } - - for (j=0; j < (int) y; ++j) { - stbi__uint16 *src = data + j * x * img_n ; - stbi__uint16 *dest = good + j * x * req_comp; - - #define STBI__COMBO(a,b) ((a)*8+(b)) - #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) - // convert source image with img_n components to one with req_comp components; - // avoid switch per pixel, so use switch per scanline and massive macros - switch (STBI__COMBO(img_n, req_comp)) { - STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; - STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; - STBI__CASE(2,1) { dest[0]=src[0]; } break; - STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; - STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; - STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; - STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; - STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; - STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; - STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; - default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, + int req_comp, unsigned int x, + unsigned int y) { + int i, j; + stbi__uint16 *good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + stbi__uint16 *src = data + j * x * img_n; + stbi__uint16 *dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a) * 8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp + // components; avoid switch per pixel, so use switch per scanline and + // massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 0xffff; + } + break; + STBI__CASE(1, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 0xffff; + } + break; + STBI__CASE(2, 1) { + dest[0] = src[0]; + } + break; + STBI__CASE(2, 3) { + dest[0] = dest[1] = dest[2] = src[0]; + } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 0xffff; + } + break; + STBI__CASE(3, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } - #undef STBI__CASE - } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = 0xffff; + } + break; + STBI__CASE(4, 1) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return (stbi__uint16 *)stbi__errpuc("unsupported", + "Unsupported format conversion"); + } +#undef STBI__CASE + } - STBI_FREE(data); - return good; + STBI_FREE(data); + return good; } #endif #ifndef STBI_NO_LINEAR -static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) -{ - int i,k,n; - float *output; - if (!data) return NULL; - output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); - } - } - if (n < comp) { - for (i=0; i < x*y; ++i) { - output[i*comp + n] = data[i*comp + n]/255.0f; - } - } - STBI_FREE(data); - return output; +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) { + int i, k, n; + float *output; + if (!data) + return NULL; + output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpf("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + output[i * comp + k] = + (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * + stbi__l2h_scale); + } + } + if (n < comp) { + for (i = 0; i < x * y; ++i) { + output[i * comp + n] = data[i * comp + n] / 255.0f; + } + } + STBI_FREE(data); + return output; } #endif #ifndef STBI_NO_HDR -#define stbi__float2int(x) ((int) (x)) -static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) -{ - int i,k,n; - stbi_uc *output; - if (!data) return NULL; - output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); - if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } - // compute number of non-alpha components - if (comp & 1) n = comp; else n = comp-1; - for (i=0; i < x*y; ++i) { - for (k=0; k < n; ++k) { - float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - if (k < comp) { - float z = data[i*comp+k] * 255 + 0.5f; - if (z < 0) z = 0; - if (z > 255) z = 255; - output[i*comp + k] = (stbi_uc) stbi__float2int(z); - } - } - STBI_FREE(data); - return output; +#define stbi__float2int(x) ((int)(x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) { + int i, k, n; + stbi_uc *output; + if (!data) + return NULL; + output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, + stbi__h2l_gamma_i) * + 255 + + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + if (k < comp) { + float z = data[i * comp + k] * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + } + STBI_FREE(data); + return output; } #endif @@ -1872,750 +2090,791 @@ static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) #ifndef STBI_NO_JPEG // huffman decoding acceleration -#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache - -typedef struct -{ - stbi_uc fast[1 << FAST_BITS]; - // weirdly, repacking this into AoS is a 10% speed loss, instead of a win - stbi__uint16 code[256]; - stbi_uc values[256]; - stbi_uc size[257]; - unsigned int maxcode[18]; - int delta[17]; // old 'firstsymbol' - old 'firstcode' +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct { + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' } stbi__huffman; -typedef struct -{ - stbi__context *s; - stbi__huffman huff_dc[4]; - stbi__huffman huff_ac[4]; - stbi__uint16 dequant[4][64]; - stbi__int16 fast_ac[4][1 << FAST_BITS]; - -// sizes for components, interleaved MCUs - int img_h_max, img_v_max; - int img_mcu_x, img_mcu_y; - int img_mcu_w, img_mcu_h; - -// definition of jpeg image component - struct - { - int id; - int h,v; - int tq; - int hd,ha; - int dc_pred; - - int x,y,w2,h2; - stbi_uc *data; - void *raw_data, *raw_coeff; - stbi_uc *linebuf; - short *coeff; // progressive only - int coeff_w, coeff_h; // number of 8x8 coefficient blocks - } img_comp[4]; - - stbi__uint32 code_buffer; // jpeg entropy-coded buffer - int code_bits; // number of valid bits - unsigned char marker; // marker seen while filling entropy buffer - int nomore; // flag if we saw a marker so must stop - - int progressive; - int spec_start; - int spec_end; - int succ_high; - int succ_low; - int eob_run; - int jfif; - int app14_color_transform; // Adobe APP14 tag - int rgb; - - int scan_n, order[4]; - int restart_interval, todo; - -// kernels - void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); - void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); - stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +typedef struct { + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, int count, + int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs); } stbi__jpeg; -static int stbi__build_huffman(stbi__huffman *h, int *count) -{ - int i,j,k=0; - unsigned int code; - // build size list for each symbol (from JPEG spec) - for (i=0; i < 16; ++i) - for (j=0; j < count[i]; ++j) - h->size[k++] = (stbi_uc) (i+1); - h->size[k] = 0; - - // compute actual symbols (from jpeg spec) - code = 0; - k = 0; - for(j=1; j <= 16; ++j) { - // compute delta to add to code to compute symbol id - h->delta[j] = k - code; - if (h->size[k] == j) { - while (h->size[k] == j) - h->code[k++] = (stbi__uint16) (code++); - if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); - } - // compute largest code + 1 for this size, preshifted as needed later - h->maxcode[j] = code << (16-j); - code <<= 1; - } - h->maxcode[j] = 0xffffffff; - - // build non-spec acceleration table; 255 is flag for not-accelerated - memset(h->fast, 255, 1 << FAST_BITS); - for (i=0; i < k; ++i) { - int s = h->size[i]; - if (s <= FAST_BITS) { - int c = h->code[i] << (FAST_BITS-s); - int m = 1 << (FAST_BITS-s); - for (j=0; j < m; ++j) { - h->fast[c+j] = (stbi_uc) i; - } +static int stbi__build_huffman(stbi__huffman *h, int *count) { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) + for (j = 0; j < count[i]; ++j) + h->size[k++] = (stbi_uc)(i + 1); + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16)(code++); + if (code - 1 >= (1u << j)) + return stbi__err("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (stbi_uc)i; } - } - return 1; + } + } + return 1; } // build a table that decodes both magnitude and value of small ACs in // one go. -static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) -{ - int i; - for (i=0; i < (1 << FAST_BITS); ++i) { - stbi_uc fast = h->fast[i]; - fast_ac[i] = 0; - if (fast < 255) { - int rs = h->values[fast]; - int run = (rs >> 4) & 15; - int magbits = rs & 15; - int len = h->size[fast]; - - if (magbits && len + magbits <= FAST_BITS) { - // magnitude code followed by receive_extend code - int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); - int m = 1 << (magbits - 1); - if (k < m) k += (~0U << magbits) + 1; - // if the result is small enough, we can fit it in fast_ac table - if (k >= -128 && k <= 127) - fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); - } +static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) + k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); } - } -} - -static void stbi__grow_buffer_unsafe(stbi__jpeg *j) -{ - do { - unsigned int b = j->nomore ? 0 : stbi__get8(j->s); - if (b == 0xff) { - int c = stbi__get8(j->s); - while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes - if (c != 0) { - j->marker = (unsigned char) c; - j->nomore = 1; - return; - } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) { + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) + c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; } - j->code_buffer |= b << (24 - j->code_bits); - j->code_bits += 8; - } while (j->code_bits <= 24); + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); } // (1 << n) - 1 -static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; +static const stbi__uint32 stbi__bmask[17] = { + 0, 1, 3, 7, 15, 31, 63, 127, 255, + 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; // decode a jpeg huffman value from the bitstream -stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) -{ - unsigned int temp; - int c,k; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - // look at the top FAST_BITS and determine what symbol ID it is, - // if the code is <= FAST_BITS - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - k = h->fast[c]; - if (k < 255) { - int s = h->size[k]; - if (s > j->code_bits) - return -1; - j->code_buffer <<= s; - j->code_bits -= s; - return h->values[k]; - } - - // naive test is to shift the code_buffer down so k bits are - // valid, then test against maxcode. To speed this up, we've - // preshifted maxcode left so that it has (16-k) 0s at the - // end; in other words, regardless of the number of bits, it - // wants to be compared against something shifted to have 16; - // that way we don't need to shift inside the loop. - temp = j->code_buffer >> 16; - for (k=FAST_BITS+1 ; ; ++k) - if (temp < h->maxcode[k]) - break; - if (k == 17) { - // error! code not found - j->code_bits -= 16; - return -1; - } - - if (k > j->code_bits) +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) return -1; - - // convert the huffman code to the symbol id - c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; - STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); - - // convert the id to a symbol - j->code_bits -= k; - j->code_buffer <<= k; - return h->values[c]; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1;; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & + stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; } // bias[n] = (-1<code_bits < n) stbi__grow_buffer_unsafe(j); - - sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB - k = stbi_lrot(j->code_buffer, n); - if (n < 0 || n >= (int) (sizeof(stbi__bmask)/sizeof(*stbi__bmask))) return 0; - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k + (stbi__jbias[n] & ~sgn); +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) { + unsigned int k; + int sgn; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + + sgn = (stbi__int32)j->code_buffer >> 31; // sign bit is always in MSB + k = stbi_lrot(j->code_buffer, n); + if (n < 0 || n >= (int)(sizeof(stbi__bmask) / sizeof(*stbi__bmask))) + return 0; + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & ~sgn); } // get some unsigned bits -stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) -{ - unsigned int k; - if (j->code_bits < n) stbi__grow_buffer_unsafe(j); - k = stbi_lrot(j->code_buffer, n); - j->code_buffer = k & ~stbi__bmask[n]; - k &= stbi__bmask[n]; - j->code_bits -= n; - return k; -} - -stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) -{ - unsigned int k; - if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); - k = j->code_buffer; - j->code_buffer <<= 1; - --j->code_bits; - return k & 0x80000000; +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) { + unsigned int k; + if (j->code_bits < n) + stbi__grow_buffer_unsafe(j); + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) { + unsigned int k; + if (j->code_bits < 1) + stbi__grow_buffer_unsafe(j); + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; } // given a value that's at position X in the zigzag stream, // where does it appear in the 8x8 matrix coded as row-major? -static const stbi_uc stbi__jpeg_dezigzag[64+15] = -{ - 0, 1, 8, 16, 9, 2, 3, 10, - 17, 24, 32, 25, 18, 11, 4, 5, - 12, 19, 26, 33, 40, 48, 41, 34, - 27, 20, 13, 6, 7, 14, 21, 28, - 35, 42, 49, 56, 57, 50, 43, 36, - 29, 22, 15, 23, 30, 37, 44, 51, - 58, 59, 52, 45, 38, 31, 39, 46, - 53, 60, 61, 54, 47, 55, 62, 63, - // let corrupt input sample past end - 63, 63, 63, 63, 63, 63, 63, 63, - 63, 63, 63, 63, 63, 63, 63 -}; +static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = { + 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, + 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, + 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63}; // decode one 64-entry block-- -static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) -{ - int diff,dc,k; - int t; - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - t = stbi__jpeg_huff_decode(j, hdc); - if (t < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - - // 0 all the ac values now so we can do it 32-bits at a time - memset(data,0,64*sizeof(data[0])); - - diff = t ? stbi__extend_receive(j, t) : 0; - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc * dequant[0]); - - // decode AC components, see JPEG spec - k = 1; - do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) * dequant[zig]); +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, stbi__huffman *hac, + stbi__int16 *fac, int b, + stbi__uint16 *dequant) { + int diff, dc, k; + int t; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data, 0, 64 * sizeof(data[0])); + + diff = t ? stbi__extend_receive(j, t) : 0; + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) + break; // end block + k += 16; } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (rs != 0xf0) break; // end block - k += 16; - } else { - k += r; - // decode into unzigzag'd location - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); - } + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); } - } while (k < 64); - return 1; -} - -static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) -{ - int diff,dc; - int t; - if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - - if (j->succ_high == 0) { - // first scan for DC coefficient, must be first - memset(data,0,64*sizeof(data[0])); // 0 all the ac values now - t = stbi__jpeg_huff_decode(j, hdc); - if (t == -1) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - diff = t ? stbi__extend_receive(j, t) : 0; - - dc = j->img_comp[b].dc_pred + diff; - j->img_comp[b].dc_pred = dc; - data[0] = (short) (dc << j->succ_low); - } else { - // refinement scan for DC coefficient - if (stbi__jpeg_get_bit(j)) - data[0] += (short) (1 << j->succ_low); - } - return 1; + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], + stbi__huffman *hdc, int b) { + int diff, dc; + int t; + if (j->spec_end != 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t == -1) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? stbi__extend_receive(j, t) : 0; + + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc << j->succ_low); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; } // @OPTIMIZE: store non-zigzagged during the decode passes, // and only de-zigzag when dequantizing -static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) -{ - int k; - if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); - - if (j->succ_high == 0) { - int shift = j->succ_low; +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], + stbi__huffman *hac, + stbi__int16 *fac) { + int k; + if (j->spec_start == 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } - if (j->eob_run) { - --j->eob_run; - return 1; + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) << shift); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) << shift); + } } - + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { k = j->spec_start; do { - unsigned int zig; - int c,r,s; - if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); - c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); - r = fac[c]; - if (r) { // fast-AC path - k += (r >> 4) & 15; // run - s = r & 15; // combined length - j->code_buffer <<= s; - j->code_bits -= s; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) ((r >> 8) << shift); - } else { - int rs = stbi__jpeg_huff_decode(j, hac); - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r); - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - --j->eob_run; - break; - } - k += 16; - } else { - k += r; - zig = stbi__jpeg_dezigzag[k++]; - data[zig] = (short) (stbi__extend_receive(j,s) << shift); - } - } - } while (k <= j->spec_end); - } else { - // refinement scan for these AC coefficients - - short bit = (short) (1 << j->succ_low); - - if (j->eob_run) { - --j->eob_run; - for (k = j->spec_start; k <= j->spec_end; ++k) { - short *p = &data[stbi__jpeg_dezigzag[k]]; - if (*p != 0) - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } - } else { - k = j->spec_start; - do { - int r,s; - int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); - s = rs & 15; - r = rs >> 4; - if (s == 0) { - if (r < 15) { - j->eob_run = (1 << r) - 1; - if (r) - j->eob_run += stbi__jpeg_get_bits(j, r); - r = 64; // force end of block - } else { - // r=15 s=0 should write 16 0s, so we just do - // a run of 15 0s and then write s (which is 0), - // so we don't have to do anything special here - } - } else { - if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); - // sign bit - if (stbi__jpeg_get_bit(j)) - s = bit; - else - s = -bit; - } + int r, s; + int rs = stbi__jpeg_huff_decode( + j, hac); // @OPTIMIZE see if we can use the fast path here, + // advance-by-r is so slow, eh + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) + return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } - // advance by r - while (k <= j->spec_end) { - short *p = &data[stbi__jpeg_dezigzag[k++]]; - if (*p != 0) { - if (stbi__jpeg_get_bit(j)) - if ((*p & bit)==0) { - if (*p > 0) - *p += bit; - else - *p -= bit; - } - } else { - if (r == 0) { - *p = (short) s; - break; - } - --r; - } + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short)s; + break; } - } while (k <= j->spec_end); - } - } - return 1; + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; } // take a -128..127 value and stbi__clamp it and convert to 0..255 -stbi_inline static stbi_uc stbi__clamp(int x) -{ - // trick to use a single test to catch both cases - if ((unsigned int) x > 255) { - if (x < 0) return 0; - if (x > 255) return 255; - } - return (stbi_uc) x; +stbi_inline static stbi_uc stbi__clamp(int x) { + // trick to use a single test to catch both cases + if ((unsigned int)x > 255) { + if (x < 0) + return 0; + if (x > 255) + return 255; + } + return (stbi_uc)x; } -#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) -#define stbi__fsh(x) ((x) * 4096) +#define stbi__f2f(x) ((int)(((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) // derived from jidctint -- DCT_ISLOW -#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ - int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ - p2 = s2; \ - p3 = s6; \ - p1 = (p2+p3) * stbi__f2f(0.5411961f); \ - t2 = p1 + p3*stbi__f2f(-1.847759065f); \ - t3 = p1 + p2*stbi__f2f( 0.765366865f); \ - p2 = s0; \ - p3 = s4; \ - t0 = stbi__fsh(p2+p3); \ - t1 = stbi__fsh(p2-p3); \ - x0 = t0+t3; \ - x3 = t0-t3; \ - x1 = t1+t2; \ - x2 = t1-t2; \ - t0 = s7; \ - t1 = s5; \ - t2 = s3; \ - t3 = s1; \ - p3 = t0+t2; \ - p4 = t1+t3; \ - p1 = t0+t3; \ - p2 = t1+t2; \ - p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ - t0 = t0*stbi__f2f( 0.298631336f); \ - t1 = t1*stbi__f2f( 2.053119869f); \ - t2 = t2*stbi__f2f( 3.072711026f); \ - t3 = t3*stbi__f2f( 1.501321110f); \ - p1 = p5 + p1*stbi__f2f(-0.899976223f); \ - p2 = p5 + p2*stbi__f2f(-2.562915447f); \ - p3 = p3*stbi__f2f(-1.961570560f); \ - p4 = p4*stbi__f2f(-0.390180644f); \ - t3 += p1+p4; \ - t2 += p2+p3; \ - t1 += p2+p4; \ - t0 += p1+p3; - -static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) -{ - int i,val[64],*v=val; - stbi_uc *o; - short *d = data; - - // columns - for (i=0; i < 8; ++i,++d, ++v) { - // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing - if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 - && d[40]==0 && d[48]==0 && d[56]==0) { - // no shortcut 0 seconds - // (1|2|3|4|5|6|7)==0 0 seconds - // all separate -0.047 seconds - // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds - int dcterm = d[0]*4; - v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; - } else { - STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) - // constants scaled things up by 1<<12; let's bring them back - // down, but keep 2 extra bits of precision - x0 += 512; x1 += 512; x2 += 512; x3 += 512; - v[ 0] = (x0+t3) >> 10; - v[56] = (x0-t3) >> 10; - v[ 8] = (x1+t2) >> 10; - v[48] = (x1-t2) >> 10; - v[16] = (x2+t1) >> 10; - v[40] = (x2-t1) >> 10; - v[24] = (x3+t0) >> 10; - v[32] = (x3-t0) >> 10; - } - } - - for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { - // no fast case since the first 1D IDCT spread components out - STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) - // constants scaled things up by 1<<12, plus we had 1<<2 from first - // loop, plus horizontal and vertical each scale by sqrt(8) so together - // we've got an extra 1<<3, so 1<<17 total we need to remove. - // so we want to round that, which means adding 0.5 * 1<<17, - // aka 65536. Also, we'll end up with -128 to 127 that we want - // to encode as 0..255 by adding 128, so we'll add that before the shift - x0 += 65536 + (128<<17); - x1 += 65536 + (128<<17); - x2 += 65536 + (128<<17); - x3 += 65536 + (128<<17); - // tried computing the shifts into temps, or'ing the temps to see - // if any were out of range, but that was slower - o[0] = stbi__clamp((x0+t3) >> 17); - o[7] = stbi__clamp((x0-t3) >> 17); - o[1] = stbi__clamp((x1+t2) >> 17); - o[6] = stbi__clamp((x1-t2) >> 17); - o[2] = stbi__clamp((x2+t1) >> 17); - o[5] = stbi__clamp((x2-t1) >> 17); - o[3] = stbi__clamp((x3+t0) >> 17); - o[4] = stbi__clamp((x3-t0) >> 17); - } +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ + t0 += p1 + p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) { + int i, val[64], *v = val; + stbi_uc *o; + short *d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && + d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; + x1 += 512; + x2 += 512; + x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0 + t3) >> 17); + o[7] = stbi__clamp((x0 - t3) >> 17); + o[1] = stbi__clamp((x1 + t2) >> 17); + o[6] = stbi__clamp((x1 - t2) >> 17); + o[2] = stbi__clamp((x2 + t1) >> 17); + o[5] = stbi__clamp((x2 - t1) >> 17); + o[3] = stbi__clamp((x3 + t0) >> 17); + o[4] = stbi__clamp((x3 - t0) >> 17); + } } #ifdef STBI_SSE2 // sse2 integer IDCT. not the fastest possible implementation but it // produces bit-identical results to the generic C version so it's // fully "transparent". -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - // This is constructed to match our regular (generic) integer IDCT exactly. - __m128i row0, row1, row2, row3, row4, row5, row6, row7; - __m128i tmp; - - // dot product constant: even elems=x, odd elems=y - #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) - - // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) - // out(1) = c1[even]*x + c1[odd]*y - #define dct_rot(out0,out1, x,y,c0,c1) \ - __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ - __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ - __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ - __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ - __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ - __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) - - // out = in << 12 (in 16-bit, out 32-bit) - #define dct_widen(out, in) \ - __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ - __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - - // wide add - #define dct_wadd(out, a, b) \ - __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - - // wide sub - #define dct_wsub(out, a, b) \ - __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ - __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) - - // butterfly a/b, add bias, then shift by "s" and pack - #define dct_bfly32o(out0, out1, a,b,bias,s) \ - { \ - __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ - __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ - dct_wadd(sum, abiased, b); \ - dct_wsub(dif, abiased, b); \ - out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ - out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ - } - - // 8-bit interleave step (for transposes) - #define dct_interleave8(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi8(a, b); \ - b = _mm_unpackhi_epi8(tmp, b) - - // 16-bit interleave step (for transposes) - #define dct_interleave16(a, b) \ - tmp = a; \ - a = _mm_unpacklo_epi16(a, b); \ - b = _mm_unpackhi_epi16(tmp, b) - - #define dct_pass(bias,shift) \ - { \ - /* even part */ \ - dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ - __m128i sum04 = _mm_add_epi16(row0, row4); \ - __m128i dif04 = _mm_sub_epi16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ - dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ - __m128i sum17 = _mm_add_epi16(row1, row7); \ - __m128i sum35 = _mm_add_epi16(row3, row5); \ - dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ - dct_wadd(x4, y0o, y4o); \ - dct_wadd(x5, y1o, y5o); \ - dct_wadd(x6, y2o, y5o); \ - dct_wadd(x7, y3o, y4o); \ - dct_bfly32o(row0,row7, x0,x7,bias,shift); \ - dct_bfly32o(row1,row6, x1,x6,bias,shift); \ - dct_bfly32o(row2,row5, x2,x5,bias,shift); \ - dct_bfly32o(row3,row4, x3,x4,bias,shift); \ - } +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + +// dot product constant: even elems=x, odd elems=y +#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + +// out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = \ + _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = \ + _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) - __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); - __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); - __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); - __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); - __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), stbi__f2f(-1.961570560f)); - __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); - __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); - __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); - - // rounding biases in column/row passes, see stbi__idct_block for explanation. - __m128i bias_0 = _mm_set1_epi32(512); - __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); - - // load - row0 = _mm_load_si128((const __m128i *) (data + 0*8)); - row1 = _mm_load_si128((const __m128i *) (data + 1*8)); - row2 = _mm_load_si128((const __m128i *) (data + 2*8)); - row3 = _mm_load_si128((const __m128i *) (data + 3*8)); - row4 = _mm_load_si128((const __m128i *) (data + 4*8)); - row5 = _mm_load_si128((const __m128i *) (data + 5*8)); - row6 = _mm_load_si128((const __m128i *) (data + 6*8)); - row7 = _mm_load_si128((const __m128i *) (data + 7*8)); - - // column pass - dct_pass(bias_0, 10); - - { - // 16bit 8x8 transpose pass 1 - dct_interleave16(row0, row4); - dct_interleave16(row1, row5); - dct_interleave16(row2, row6); - dct_interleave16(row3, row7); - - // transpose pass 2 - dct_interleave16(row0, row2); - dct_interleave16(row1, row3); - dct_interleave16(row4, row6); - dct_interleave16(row5, row7); - - // transpose pass 3 - dct_interleave16(row0, row1); - dct_interleave16(row2, row3); - dct_interleave16(row4, row5); - dct_interleave16(row6, row7); - } - - // row pass - dct_pass(bias_1, 17); - - { - // pack - __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 - __m128i p1 = _mm_packus_epi16(row2, row3); - __m128i p2 = _mm_packus_epi16(row4, row5); - __m128i p3 = _mm_packus_epi16(row6, row7); - - // 8bit 8x8 transpose pass 1 - dct_interleave8(p0, p2); // a0e0a1e1... - dct_interleave8(p1, p3); // c0g0c1g1... - - // transpose pass 2 - dct_interleave8(p0, p1); // a0c0e0g0... - dct_interleave8(p2, p3); // b0d0f0h0... - - // transpose pass 3 - dct_interleave8(p0, p2); // a0b0c0d0... - dct_interleave8(p1, p3); // a4b4c4d4... +// wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) - // store - _mm_storel_epi64((__m128i *) out, p0); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p2); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p1); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; - _mm_storel_epi64((__m128i *) out, p3); out += out_stride; - _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); - } +// wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + +// butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = \ + _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = \ + _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + +// 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + +// 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), + stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), + stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), + stbi__f2f(1.175875602f)); + __m128i rot1_1 = + dct_const(stbi__f2f(1.175875602f), + stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), + stbi__f2f(-1.961570560f)); + __m128i rot2_1 = + dct_const(stbi__f2f(-1.961570560f), + stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), + stbi__f2f(-0.390180644f)); + __m128i rot3_1 = + dct_const(stbi__f2f(-0.390180644f), + stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); + row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); + row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); + row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); + row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); + row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); + row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); + row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *)out, p0); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p2); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p1); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p3); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); + } #undef dct_const #undef dct_rot @@ -2634,198 +2893,236 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) // NEON integer IDCT. should produce bit-identical // results to the generic C version. -static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) -{ - int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; - - int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); - int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); - int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); - int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); - int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); - int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); - int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); - int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); - int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); - int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); - int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); - int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); - -#define dct_long_mul(out, inq, coeff) \ - int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) - -#define dct_long_mac(out, acc, inq, coeff) \ - int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ - int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) - -#define dct_widen(out, inq) \ - int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ - int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) +static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) { + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) // wide add -#define dct_wadd(out, a, b) \ - int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vaddq_s32(a##_h, b##_h) +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) // wide sub -#define dct_wsub(out, a, b) \ - int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ - int32x4_t out##_h = vsubq_s32(a##_h, b##_h) +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) // butterfly a/b, then shift using "shiftop" by "s" and pack -#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ - { \ - dct_wadd(sum, a, b); \ - dct_wsub(dif, a, b); \ - out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ - out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ - } - -#define dct_pass(shiftop, shift) \ - { \ - /* even part */ \ - int16x8_t sum26 = vaddq_s16(row2, row6); \ - dct_long_mul(p1e, sum26, rot0_0); \ - dct_long_mac(t2e, p1e, row6, rot0_1); \ - dct_long_mac(t3e, p1e, row2, rot0_2); \ - int16x8_t sum04 = vaddq_s16(row0, row4); \ - int16x8_t dif04 = vsubq_s16(row0, row4); \ - dct_widen(t0e, sum04); \ - dct_widen(t1e, dif04); \ - dct_wadd(x0, t0e, t3e); \ - dct_wsub(x3, t0e, t3e); \ - dct_wadd(x1, t1e, t2e); \ - dct_wsub(x2, t1e, t2e); \ - /* odd part */ \ - int16x8_t sum15 = vaddq_s16(row1, row5); \ - int16x8_t sum17 = vaddq_s16(row1, row7); \ - int16x8_t sum35 = vaddq_s16(row3, row5); \ - int16x8_t sum37 = vaddq_s16(row3, row7); \ - int16x8_t sumodd = vaddq_s16(sum17, sum35); \ - dct_long_mul(p5o, sumodd, rot1_0); \ - dct_long_mac(p1o, p5o, sum17, rot1_1); \ - dct_long_mac(p2o, p5o, sum35, rot1_2); \ - dct_long_mul(p3o, sum37, rot2_0); \ - dct_long_mul(p4o, sum15, rot2_1); \ - dct_wadd(sump13o, p1o, p3o); \ - dct_wadd(sump24o, p2o, p4o); \ - dct_wadd(sump23o, p2o, p3o); \ - dct_wadd(sump14o, p1o, p4o); \ - dct_long_mac(x4, sump13o, row7, rot3_0); \ - dct_long_mac(x5, sump24o, row5, rot3_1); \ - dct_long_mac(x6, sump23o, row3, rot3_2); \ - dct_long_mac(x7, sump14o, row1, rot3_3); \ - dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ - dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ - dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ - dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ - } - - // load - row0 = vld1q_s16(data + 0*8); - row1 = vld1q_s16(data + 1*8); - row2 = vld1q_s16(data + 2*8); - row3 = vld1q_s16(data + 3*8); - row4 = vld1q_s16(data + 4*8); - row5 = vld1q_s16(data + 5*8); - row6 = vld1q_s16(data + 6*8); - row7 = vld1q_s16(data + 7*8); - - // add DC bias - row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); - - // column pass - dct_pass(vrshrn_n_s32, 10); - - // 16bit 8x8 transpose - { +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ + } + + // load + row0 = vld1q_s16(data + 0 * 8); + row1 = vld1q_s16(data + 1 * 8); + row2 = vld1q_s16(data + 2 * 8); + row3 = vld1q_s16(data + 3 * 8); + row4 = vld1q_s16(data + 4 * 8); + row5 = vld1q_s16(data + 5 * 8); + row6 = vld1q_s16(data + 6 * 8); + row7 = vld1q_s16(data + 7 * 8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { // these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. // whether compilers actually get this is another story, sadly. -#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } -#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } - - // pass 1 - dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 - dct_trn16(row2, row3); - dct_trn16(row4, row5); - dct_trn16(row6, row7); - - // pass 2 - dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 - dct_trn32(row1, row3); - dct_trn32(row4, row6); - dct_trn32(row5, row7); - - // pass 3 - dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 - dct_trn64(row1, row5); - dct_trn64(row2, row6); - dct_trn64(row3, row7); +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = \ + vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ + } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); #undef dct_trn16 #undef dct_trn32 #undef dct_trn64 - } - - // row pass - // vrshrn_n_s32 only supports shifts up to 16, we need - // 17. so do a non-rounding shift of 16 first then follow - // up with a rounding shift by 1. - dct_pass(vshrn_n_s32, 16); - - { - // pack and round - uint8x8_t p0 = vqrshrun_n_s16(row0, 1); - uint8x8_t p1 = vqrshrun_n_s16(row1, 1); - uint8x8_t p2 = vqrshrun_n_s16(row2, 1); - uint8x8_t p3 = vqrshrun_n_s16(row3, 1); - uint8x8_t p4 = vqrshrun_n_s16(row4, 1); - uint8x8_t p5 = vqrshrun_n_s16(row5, 1); - uint8x8_t p6 = vqrshrun_n_s16(row6, 1); - uint8x8_t p7 = vqrshrun_n_s16(row7, 1); - - // again, these can translate into one instruction, but often don't. -#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } -#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } -#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } - - // sadly can't use interleaved stores here since we only write - // 8 bytes to each scan line! - - // 8x8 8-bit transpose pass 1 - dct_trn8_8(p0, p1); - dct_trn8_8(p2, p3); - dct_trn8_8(p4, p5); - dct_trn8_8(p6, p7); - - // pass 2 - dct_trn8_16(p0, p2); - dct_trn8_16(p1, p3); - dct_trn8_16(p4, p6); - dct_trn8_16(p5, p7); - - // pass 3 - dct_trn8_32(p0, p4); - dct_trn8_32(p1, p5); - dct_trn8_32(p2, p6); - dct_trn8_32(p3, p7); - - // store - vst1_u8(out, p0); out += out_stride; - vst1_u8(out, p1); out += out_stride; - vst1_u8(out, p2); out += out_stride; - vst1_u8(out, p3); out += out_stride; - vst1_u8(out, p4); out += out_stride; - vst1_u8(out, p5); out += out_stride; - vst1_u8(out, p6); out += out_stride; - vst1_u8(out, p7); + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ + } +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ + } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); + out += out_stride; + vst1_u8(out, p1); + out += out_stride; + vst1_u8(out, p2); + out += out_stride; + vst1_u8(out, p3); + out += out_stride; + vst1_u8(out, p4); + out += out_stride; + vst1_u8(out, p5); + out += out_stride; + vst1_u8(out, p6); + out += out_stride; + vst1_u8(out, p7); #undef dct_trn8_8 #undef dct_trn8_16 #undef dct_trn8_32 - } + } #undef dct_long_mul #undef dct_long_mac @@ -2838,1132 +3135,1274 @@ static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) #endif // STBI_NEON -#define STBI__MARKER_none 0xff +#define STBI__MARKER_none 0xff // if there's a pending marker from the entropy stream, return that // otherwise, fetch from the stream and get a marker. if there's no // marker, return 0xff, which is never a valid marker value -static stbi_uc stbi__get_marker(stbi__jpeg *j) -{ - stbi_uc x; - if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; } - x = stbi__get8(j->s); - if (x != 0xff) return STBI__MARKER_none; - while (x == 0xff) - x = stbi__get8(j->s); // consume repeated 0xff fill bytes - return x; +static stbi_uc stbi__get_marker(stbi__jpeg *j) { + stbi_uc x; + if (j->marker != STBI__MARKER_none) { + x = j->marker; + j->marker = STBI__MARKER_none; + return x; + } + x = stbi__get8(j->s); + if (x != 0xff) + return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; } // in each scan, we'll have scan_n components, and the order // of the components is specified by order[] -#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) // after a restart interval, stbi__jpeg_reset the entropy decoder and // the dc prediction -static void stbi__jpeg_reset(stbi__jpeg *j) -{ - j->code_bits = 0; - j->code_buffer = 0; - j->nomore = 0; - j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; - j->marker = STBI__MARKER_none; - j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; - j->eob_run = 0; - // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, - // since we don't even allow 1<<30 pixels -} - -static int stbi__parse_entropy_coded_data(stbi__jpeg *z) -{ - stbi__jpeg_reset(z); - if (!z->progressive) { - if (z->scan_n == 1) { - int i,j; - STBI_SIMD_ALIGN(short, data[64]); - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - // if it's NOT a restart, then just bail, so we get corrupt data - // rather than no data - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - STBI_SIMD_ALIGN(short, data[64]); - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x)*8; - int y2 = (j*z->img_comp[n].v + y)*8; - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; +static void stbi__jpeg_reset(stbi__jpeg *j) { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = + j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg *z) { + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } else { - if (z->scan_n == 1) { - int i,j; - int n = z->order[0]; - // non-interleaved data, we just need to process one block at a time, - // in trivial scanline order - // number of blocks to do just depends on how many actual "pixels" this - // component has, independent of interleaved MCU blocking and such - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - if (z->spec_start == 0) { - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } else { - int ha = z->img_comp[n].ha; - if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) - return 0; - } - // every data block is an MCU, so countdown the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } - } - } - return 1; - } else { // interleaved - int i,j,k,x,y; - for (j=0; j < z->img_mcu_y; ++j) { - for (i=0; i < z->img_mcu_x; ++i) { - // scan an interleaved mcu... process scan_n components in order - for (k=0; k < z->scan_n; ++k) { - int n = z->order[k]; - // scan out an mcu's worth of this component; that's just determined - // by the basic H and V specified for the component - for (y=0; y < z->img_comp[n].v; ++y) { - for (x=0; x < z->img_comp[n].h; ++x) { - int x2 = (i*z->img_comp[n].h + x); - int y2 = (j*z->img_comp[n].v + y); - short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); - if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) - return 0; - } - } - } - // after all interleaved components, that's an interleaved MCU, - // so now count down the restart interval - if (--z->todo <= 0) { - if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); - if (!STBI__RESTART(z->marker)) return 1; - stbi__jpeg_reset(z); - } + return 1; + } else { // interleaved + int i, j, k, x, y; + STBI_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, + z->huff_dc + z->img_comp[n].hd, + z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + + z->img_comp[n].w2 * y2 + x2, + z->img_comp[n].w2, data); + } } - } - return 1; + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) -{ - int i; - for (i=0; i < 64; ++i) - data[i] *= dequant[i]; -} - -static void stbi__jpeg_finish(stbi__jpeg *z) -{ - if (z->progressive) { - // dequantize and idct the data - int i,j,n; - for (n=0; n < z->s->img_n; ++n) { - int w = (z->img_comp[n].x+7) >> 3; - int h = (z->img_comp[n].y+7) >> 3; - for (j=0; j < h; ++j) { - for (i=0; i < w; ++i) { - short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); - stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); - z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); - } - } + return 1; + } + } else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], + z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - } -} - -static int stbi__process_marker(stbi__jpeg *z, int m) -{ - int L; - switch (m) { - case STBI__MARKER_none: // no marker found - return stbi__err("expected marker","Corrupt JPEG"); - - case 0xDD: // DRI - specify restart interval - if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); - z->restart_interval = stbi__get16be(z->s); - return 1; - - case 0xDB: // DQT - define quantization table - L = stbi__get16be(z->s)-2; - while (L > 0) { - int q = stbi__get8(z->s); - int p = q >> 4, sixteen = (p != 0); - int t = q & 15,i; - if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); - if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); - - for (i=0; i < 64; ++i) - z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); - L -= (sixteen ? 129 : 65); - } - return L==0; - - case 0xC4: // DHT - define huffman table - L = stbi__get16be(z->s)-2; - while (L > 0) { - stbi_uc *v; - int sizes[16],i,n=0; - int q = stbi__get8(z->s); - int tc = q >> 4; - int th = q & 15; - if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); - for (i=0; i < 16; ++i) { - sizes[i] = stbi__get8(z->s); - n += sizes[i]; - } - L -= 17; - if (tc == 0) { - if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; - v = z->huff_dc[th].values; - } else { - if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; - v = z->huff_ac[th].values; + return 1; + } else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc( + z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } } - for (i=0; i < n; ++i) - v[i] = stbi__get8(z->s); - if (tc != 0) - stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); - L -= n; - } - return L==0; - } - - // check for comment block or APP blocks - if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { - L = stbi__get16be(z->s); - if (L < 2) { - if (m == 0xFE) - return stbi__err("bad COM len","Corrupt JPEG"); - else - return stbi__err("bad APP len","Corrupt JPEG"); + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } } - L -= 2; - - if (m == 0xE0 && L >= 5) { // JFIF APP0 segment - static const unsigned char tag[5] = {'J','F','I','F','\0'}; - int ok = 1; - int i; - for (i=0; i < 5; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 5; - if (ok) - z->jfif = 1; - } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment - static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; - int ok = 1; - int i; - for (i=0; i < 6; ++i) - if (stbi__get8(z->s) != tag[i]) - ok = 0; - L -= 6; - if (ok) { - stbi__get8(z->s); // version - stbi__get16be(z->s); // flags0 - stbi__get16be(z->s); // flags1 - z->app14_color_transform = stbi__get8(z->s); // color transform - L -= 6; - } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short *data = + z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + + i * 8, + z->img_comp[n].w2, data); + } } + } + } +} - stbi__skip(z->s, L); - return 1; - } +static int stbi__process_marker(stbi__jpeg *z, int m) { + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker", "Corrupt JPEG"); - return stbi__err("unknown marker","Corrupt JPEG"); -} + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) + return stbi__err("bad DRI len", "Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; -// after we see SOS -static int stbi__process_scan_header(stbi__jpeg *z) -{ - int i; - int Ls = stbi__get16be(z->s); - z->scan_n = stbi__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); - if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); - for (i=0; i < z->scan_n; ++i) { - int id = stbi__get8(z->s), which; + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s) - 2; + while (L > 0) { int q = stbi__get8(z->s); - for (which = 0; which < z->s->img_n; ++which) - if (z->img_comp[which].id == id) - break; - if (which == z->s->img_n) return 0; // no match - z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); - z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); - z->order[i] = which; - } - - { - int aa; - z->spec_start = stbi__get8(z->s); - z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 - aa = stbi__get8(z->s); - z->succ_high = (aa >> 4); - z->succ_low = (aa & 15); - if (z->progressive) { - if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return stbi__err("bad SOS", "Corrupt JPEG"); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) + return stbi__err("bad DQT type", "Corrupt JPEG"); + if (t > 3) + return stbi__err("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = + (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + stbi_uc *v; + int sizes[16], i, n = 0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) + return stbi__err("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc + th, sizes)) + return 0; + v = z->huff_dc[th].values; } else { - if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); - z->spec_end = 63; + if (!stbi__build_huffman(z->huff_ac + th, sizes)) + return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len", "Corrupt JPEG"); + else + return stbi__err("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; } - } + } + + stbi__skip(z->s, L); + return 1; + } - return 1; + return stbi__err("unknown marker", "Corrupt JPEG"); } -static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) -{ - int i; - for (i=0; i < ncomp; ++i) { - if (z->img_comp[i].raw_data) { - STBI_FREE(z->img_comp[i].raw_data); - z->img_comp[i].raw_data = NULL; - z->img_comp[i].data = NULL; - } - if (z->img_comp[i].raw_coeff) { - STBI_FREE(z->img_comp[i].raw_coeff); - z->img_comp[i].raw_coeff = 0; - z->img_comp[i].coeff = 0; - } - if (z->img_comp[i].linebuf) { - STBI_FREE(z->img_comp[i].linebuf); - z->img_comp[i].linebuf = NULL; - } - } - return why; +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) { + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) + return stbi__err("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) + return stbi__err("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) + return 0; // no match + z->img_comp[which].hd = q >> 4; + if (z->img_comp[which].hd > 3) + return stbi__err("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; + if (z->img_comp[which].ha > 3) + return stbi__err("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || + z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; } -static int stbi__process_frame_header(stbi__jpeg *z, int scan) -{ - stbi__context *s = z->s; - int Lf,p,i,q, h_max=1,v_max=1,c; - Lf = stbi__get16be(s); if (Lf < 11) return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG - p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - c = stbi__get8(s); - if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); - s->img_n = c; - for (i=0; i < c; ++i) { +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; z->img_comp[i].data = NULL; - z->img_comp[i].linebuf = NULL; - } - - if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); - - z->rgb = 0; - for (i=0; i < s->img_n; ++i) { - static const unsigned char rgb[3] = { 'R', 'G', 'B' }; - z->img_comp[i].id = stbi__get8(s); - if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) - ++z->rgb; - q = stbi__get8(s); - z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); - z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); - z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); - } - - if (scan != STBI__SCAN_load) return 1; - - if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); - - for (i=0; i < s->img_n; ++i) { - if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; - if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; - } - - // compute interleaved mcu info - z->img_h_max = h_max; - z->img_v_max = v_max; - z->img_mcu_w = h_max * 8; - z->img_mcu_h = v_max * 8; - // these sizes can't be more than 17 bits - z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; - z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; - - for (i=0; i < s->img_n; ++i) { - // number of effective pixels (e.g. for non-interleaved MCU) - z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; - z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; - // to simplify generation, we'll allocate enough memory to decode - // the bogus oversized data from using interleaved MCUs and their - // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't - // discard the extra data until colorspace conversion - // - // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) - // so these muls can't overflow with 32-bit ints (which we require) - z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; - z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; - z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); z->img_comp[i].linebuf = NULL; - z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); - if (z->img_comp[i].raw_data == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - // align blocks for idct using mmx/sse - z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); - if (z->progressive) { - // w2, h2 are multiples of 8 (see above) - z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; - z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; - z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); - if (z->img_comp[i].raw_coeff == NULL) - return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); - z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); - } - } + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) { + stbi__context *s = z->s; + int Lf, p, i, q, h_max = 1, v_max = 1, c; + Lf = stbi__get16be(s); + if (Lf < 11) + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + p = stbi__get8(s); + if (p != 8) + return stbi__err("only 8-bit", + "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); + if (s->img_y == 0) + return stbi__err( + "no header height", + "JPEG format not supported: delayed height"); // Legal, but we don't + // handle it--but neither + // does IJG + s->img_x = stbi__get16be(s); + if (s->img_x == 0) + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) + return stbi__err("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) + return stbi__err("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = {'R', 'G', 'B'}; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); + if (!z->img_comp[i].h || z->img_comp[i].h > 4) + return stbi__err("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; + if (!z->img_comp[i].v || z->img_comp[i].v > 4) + return stbi__err("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); + if (z->img_comp[i].tq > 3) + return stbi__err("bad TQ", "Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) + return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) + return stbi__err("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) + h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) + v_max = z->img_comp[i].v; + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked + // earlier) so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = + stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i + 1, + stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = + (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3( + z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components( + z, i + 1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = + (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } - return 1; + return 1; } // use comparisons since in some cases we handle more than one case (e.g. SOF) -#define stbi__DNL(x) ((x) == 0xdc) -#define stbi__SOI(x) ((x) == 0xd8) -#define stbi__EOI(x) ((x) == 0xd9) -#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) -#define stbi__SOS(x) ((x) == 0xda) - -#define stbi__SOF_progressive(x) ((x) == 0xc2) - -static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) -{ - int m; - z->jfif = 0; - z->app14_color_transform = -1; // valid values are 0,1,2 - z->marker = STBI__MARKER_none; // initialize cached marker to empty - m = stbi__get_marker(z); - if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); - if (scan == STBI__SCAN_type) return 1; - m = stbi__get_marker(z); - while (!stbi__SOF(m)) { - if (!stbi__process_marker(z,m)) return 0; +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) + return stbi__err("no SOI", "Corrupt JPEG"); + if (scan == STBI__SCAN_type) + return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z, m)) + return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) + return stbi__err("no SOF", "Corrupt JPEG"); m = stbi__get_marker(z); - while (m == STBI__MARKER_none) { - // some files have extra padding after their blocks, so ok, we'll scan - if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); - m = stbi__get_marker(z); - } - } - z->progressive = stbi__SOF_progressive(m); - if (!stbi__process_frame_header(z, scan)) return 0; - return 1; + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) + return 0; + return 1; } // decode image to YCbCr format -static int stbi__decode_jpeg_image(stbi__jpeg *j) -{ - int m; - for (m = 0; m < 4; m++) { - j->img_comp[m].raw_data = NULL; - j->img_comp[m].raw_coeff = NULL; - } - j->restart_interval = 0; - if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; - m = stbi__get_marker(j); - while (!stbi__EOI(m)) { - if (stbi__SOS(m)) { - if (!stbi__process_scan_header(j)) return 0; - if (!stbi__parse_entropy_coded_data(j)) return 0; - if (j->marker == STBI__MARKER_none ) { - // handle 0s at the end of image data from IP Kamera 9060 - while (!stbi__at_eof(j->s)) { - int x = stbi__get8(j->s); - if (x == 255) { - j->marker = stbi__get8(j->s); - break; - } - } - // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 - } - } else if (stbi__DNL(m)) { - int Ld = stbi__get16be(j->s); - stbi__uint32 NL = stbi__get16be(j->s); - if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); - } else { - if (!stbi__process_marker(j, m)) return 0; +static int stbi__decode_jpeg_image(stbi__jpeg *j) { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) + return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) + return 0; + if (!stbi__parse_entropy_coded_data(j)) + return 0; + if (j->marker == STBI__MARKER_none) { + // handle 0s at the end of image data from IP Kamera 9060 + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + if (x == 255) { + j->marker = stbi__get8(j->s); + break; + } + } + // if we reach eof without hitting a marker, stbi__get_marker() below + // will fail and we'll eventually return 0 } - m = stbi__get_marker(j); - } - if (j->progressive) - stbi__jpeg_finish(j); - return 1; + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) + return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) + return stbi__err("bad DNL height", "Corrupt JPEG"); + } else { + if (!stbi__process_marker(j, m)) + return 0; + } + m = stbi__get_marker(j); + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; } // static jfif-centered resampling (across block boundaries) typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, - int w, int hs); - -#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) - -static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - STBI_NOTUSED(out); - STBI_NOTUSED(in_far); - STBI_NOTUSED(w); - STBI_NOTUSED(hs); - return in_near; -} - -static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples vertically for every one in input - int i; - STBI_NOTUSED(hs); - for (i=0; i < w; ++i) - out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); - return out; + int w, int hs); + +#define stbi__div4(x) ((stbi_uc)((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, + int w, int hs) { + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc *stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc *stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = stbi__div4(n + input[i - 1]); + out[i * 2 + 1] = stbi__div4(n + input[i + 1]); + } + out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc)((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = stbi__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; } -static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate two samples horizontally for every one in input - int i; - stbi_uc *input = in_near; - - if (w == 1) { - // if only one sample, can't do any interpolation - out[0] = out[1] = input[0]; - return out; - } - - out[0] = input[0]; - out[1] = stbi__div4(input[0]*3 + input[1] + 2); - for (i=1; i < w-1; ++i) { - int n = 3*input[i]+2; - out[i*2+0] = stbi__div4(n+input[i-1]); - out[i*2+1] = stbi__div4(n+input[i+1]); - } - out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); - out[i*2+1] = input[w-1]; - - STBI_NOTUSED(in_far); - STBI_NOTUSED(hs); - - return out; -} +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = + _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *)(out + i * 2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = + vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif -#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } -static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i,t0,t1; - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); - t1 = 3*in_near[0] + in_far[0]; - out[0] = stbi__div4(t1+2); - for (i=1; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); - STBI_NOTUSED(hs); + STBI_NOTUSED(hs); - return out; + return out; } +#endif -#if defined(STBI_SSE2) || defined(STBI_NEON) -static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // need to generate 2x2 samples for every one in input - int i=0,t0,t1; - - if (w == 1) { - out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); - return out; - } - - t1 = 3*in_near[0] + in_far[0]; - // process groups of 8 pixels for as long as we can. - // note we can't handle the last pixel in a row in this loop - // because we need to handle the filter boundary conditions. - for (; i < ((w-1) & ~7); i += 8) { -#if defined(STBI_SSE2) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - __m128i zero = _mm_setzero_si128(); - __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); - __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); - __m128i farw = _mm_unpacklo_epi8(farb, zero); - __m128i nearw = _mm_unpacklo_epi8(nearb, zero); - __m128i diff = _mm_sub_epi16(farw, nearw); - __m128i nears = _mm_slli_epi16(nearw, 2); - __m128i curr = _mm_add_epi16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - __m128i prv0 = _mm_slli_si128(curr, 2); - __m128i nxt0 = _mm_srli_si128(curr, 2); - __m128i prev = _mm_insert_epi16(prv0, t1, 0); - __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - __m128i bias = _mm_set1_epi16(8); - __m128i curs = _mm_slli_epi16(curr, 2); - __m128i prvd = _mm_sub_epi16(prev, curr); - __m128i nxtd = _mm_sub_epi16(next, curr); - __m128i curb = _mm_add_epi16(curs, bias); - __m128i even = _mm_add_epi16(prvd, curb); - __m128i odd = _mm_add_epi16(nxtd, curb); - - // interleave even and odd pixels, then undo scaling. - __m128i int0 = _mm_unpacklo_epi16(even, odd); - __m128i int1 = _mm_unpackhi_epi16(even, odd); - __m128i de0 = _mm_srli_epi16(int0, 4); - __m128i de1 = _mm_srli_epi16(int1, 4); - - // pack and write output - __m128i outv = _mm_packus_epi16(de0, de1); - _mm_storeu_si128((__m128i *) (out + i*2), outv); -#elif defined(STBI_NEON) - // load and perform the vertical filtering pass - // this uses 3*x + y = 4*x + (y - x) - uint8x8_t farb = vld1_u8(in_far + i); - uint8x8_t nearb = vld1_u8(in_near + i); - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); - int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); - int16x8_t curr = vaddq_s16(nears, diff); // current row - - // horizontal filter works the same based on shifted vers of current - // row. "prev" is current row shifted right by 1 pixel; we need to - // insert the previous pixel value (from t1). - // "next" is current row shifted left by 1 pixel, with first pixel - // of next block of 8 pixels added in. - int16x8_t prv0 = vextq_s16(curr, curr, 7); - int16x8_t nxt0 = vextq_s16(curr, curr, 1); - int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); - int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); - - // horizontal filter, polyphase implementation since it's convenient: - // even pixels = 3*cur + prev = cur*4 + (prev - cur) - // odd pixels = 3*cur + next = cur*4 + (next - cur) - // note the shared term. - int16x8_t curs = vshlq_n_s16(curr, 2); - int16x8_t prvd = vsubq_s16(prev, curr); - int16x8_t nxtd = vsubq_s16(next, curr); - int16x8_t even = vaddq_s16(curs, prvd); - int16x8_t odd = vaddq_s16(curs, nxtd); - - // undo scaling and round, then store with even/odd phases interleaved - uint8x8x2_t o; - o.val[0] = vqrshrun_n_s16(even, 4); - o.val[1] = vqrshrun_n_s16(odd, 4); - vst2_u8(out + i*2, o); -#endif - - // "previous" value for next iter - t1 = 3*in_near[i+7] + in_far[i+7]; - } - - t0 = t1; - t1 = 3*in_near[i] + in_far[i]; - out[i*2] = stbi__div16(3*t1 + t0 + 8); - - for (++i; i < w; ++i) { - t0 = t1; - t1 = 3*in_near[i]+in_far[i]; - out[i*2-1] = stbi__div16(3*t0 + t1 + 8); - out[i*2 ] = stbi__div16(3*t1 + t0 + 8); - } - out[w*2-1] = stbi__div4(t1+2); - - STBI_NOTUSED(hs); - - return out; -} -#endif - -static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) -{ - // resample with nearest-neighbor - int i,j; - STBI_NOTUSED(in_far); - for (i=0; i < w; ++i) - for (j=0; j < hs; ++j) - out[i*hs+j] = in_near[i]; - return out; +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, + stbi_uc *in_far, int w, int hs) { + // resample with nearest-neighbor + int i, j; + STBI_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; } // this is a reduced-precision calculation of YCbCr-to-RGB introduced // to make sure the code produces the same results in both SIMD and scalar -#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) -static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) -{ - int i; - for (i=0; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } +#define stbi__float2fixed(x) (((int)((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, + const stbi_uc *pcb, const stbi_uc *pcr, + int count, int step) { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #if defined(STBI_SSE2) || defined(STBI_NEON) -static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) -{ - int i = 0; +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, + stbi_uc const *pcb, stbi_uc const *pcr, + int count, int step) { + int i = 0; #ifdef STBI_SSE2 - // step == 3 is pretty ugly on the final interleave, and i'm not convinced - // it's useful in practice (you wouldn't use it for textures, for example). - // so just accelerate step == 4 case. - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - __m128i signflip = _mm_set1_epi8(-0x80); - __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); - __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); - __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); - __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); - __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); - __m128i xw = _mm_set1_epi16(255); // alpha channel - - for (; i+7 < count; i += 8) { - // load - __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); - __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); - __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); - __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 - __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 - - // unpack to short (and left-shift cr, cb by 8) - __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); - __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); - __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); - - // color transform - __m128i yws = _mm_srli_epi16(yw, 4); - __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); - __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); - __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); - __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); - __m128i rws = _mm_add_epi16(cr0, yws); - __m128i gwt = _mm_add_epi16(cb0, yws); - __m128i bws = _mm_add_epi16(yws, cb1); - __m128i gws = _mm_add_epi16(gwt, cr1); - - // descale - __m128i rw = _mm_srai_epi16(rws, 4); - __m128i bw = _mm_srai_epi16(bws, 4); - __m128i gw = _mm_srai_epi16(gws, 4); - - // back to byte, set up for transpose - __m128i brb = _mm_packus_epi16(rw, bw); - __m128i gxb = _mm_packus_epi16(gw, xw); - - // transpose to interleave channels - __m128i t0 = _mm_unpacklo_epi8(brb, gxb); - __m128i t1 = _mm_unpackhi_epi8(brb, gxb); - __m128i o0 = _mm_unpacklo_epi16(t0, t1); - __m128i o1 = _mm_unpackhi_epi16(t0, t1); - - // store - _mm_storeu_si128((__m128i *) (out + 0), o0); - _mm_storeu_si128((__m128i *) (out + 16), o1); - out += 32; - } - } + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *)(out + 0), o0); + _mm_storeu_si128((__m128i *)(out + 16), o1); + out += 32; + } + } #endif #ifdef STBI_NEON - // in this version, step=3 support would be easy to add. but is there demand? - if (step == 4) { - // this is a fairly straightforward implementation and not super-optimized. - uint8x8_t signflip = vdup_n_u8(0x80); - int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); - int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); - int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); - int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); - - for (; i+7 < count; i += 8) { - // load - uint8x8_t y_bytes = vld1_u8(y + i); - uint8x8_t cr_bytes = vld1_u8(pcr + i); - uint8x8_t cb_bytes = vld1_u8(pcb + i); - int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); - int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); - - // expand to s16 - int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); - int16x8_t crw = vshll_n_s8(cr_biased, 7); - int16x8_t cbw = vshll_n_s8(cb_biased, 7); - - // color transform - int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); - int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); - int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); - int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); - int16x8_t rws = vaddq_s16(yws, cr0); - int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); - int16x8_t bws = vaddq_s16(yws, cb1); - - // undo scaling, round, convert to byte - uint8x8x4_t o; - o.val[0] = vqrshrun_n_s16(rws, 4); - o.val[1] = vqrshrun_n_s16(gws, 4); - o.val[2] = vqrshrun_n_s16(bws, 4); - o.val[3] = vdup_n_u8(255); - - // store, interleaving r/g/b/a - vst4_u8(out, o); - out += 8*4; - } - } -#endif - - for (; i < count; ++i) { - int y_fixed = (y[i] << 20) + (1<<19); // rounding - int r,g,b; - int cr = pcr[i] - 128; - int cb = pcb[i] - 128; - r = y_fixed + cr* stbi__float2fixed(1.40200f); - g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); - b = y_fixed + cb* stbi__float2fixed(1.77200f); - r >>= 20; - g >>= 20; - b >>= 20; - if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } - if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } - if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } - out[0] = (stbi_uc)r; - out[1] = (stbi_uc)g; - out[2] = (stbi_uc)b; - out[3] = 255; - out += step; - } + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + cr * -stbi__float2fixed(0.71414f) + + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } } #endif // set up the kernels -static void stbi__setup_jpeg(stbi__jpeg *j) -{ - j->idct_block_kernel = stbi__idct_block; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; +static void stbi__setup_jpeg(stbi__jpeg *j) { + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; #ifdef STBI_SSE2 - if (stbi__sse2_available()) { - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; - } + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } #endif #ifdef STBI_NEON - j->idct_block_kernel = stbi__idct_simd; - j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; - j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; #endif } // clean up the temporary component buffers -static void stbi__cleanup_jpeg(stbi__jpeg *j) -{ - stbi__free_jpeg_components(j, j->s->img_n, 0); +static void stbi__cleanup_jpeg(stbi__jpeg *j) { + stbi__free_jpeg_components(j, j->s->img_n, 0); } -typedef struct -{ - resample_row_func resample; - stbi_uc *line0,*line1; - int hs,vs; // expansion factor in each axis - int w_lores; // horizontal pixels pre-expansion - int ystep; // how far through vertical expansion we are - int ypos; // which pre-expansion row we're on +typedef struct { + resample_row_func resample; + stbi_uc *line0, *line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on } stbi__resample; // fast 0..255 * 0..255 => 0..255 rounded multiplication -static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) -{ - unsigned int t = x*y + 128; - return (stbi_uc) ((t + (t >>8)) >> 8); -} - -static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) -{ - int n, decode_n, is_rgb; - z->s->img_n = 0; // make stbi__cleanup_jpeg safe - - // validate req_comp - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - - // load a jpeg image from whichever source, but leave in YCbCr format - if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } - - // determine actual number of components to generate - n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; - - is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); - - if (z->s->img_n == 3 && n < 3 && !is_rgb) - decode_n = 1; - else - decode_n = z->s->img_n; - - // resample and color-convert - { - int k; - unsigned int i,j; - stbi_uc *output; - stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; - - stbi__resample res_comp[4]; - - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - - // allocate line buffer big enough for upsampling off the edges - // with upsample factor of 4 - z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); - if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - r->hs = z->img_h_max / z->img_comp[k].h; - r->vs = z->img_v_max / z->img_comp[k].v; - r->ystep = r->vs >> 1; - r->w_lores = (z->s->img_x + r->hs-1) / r->hs; - r->ypos = 0; - r->line0 = r->line1 = z->img_comp[k].data; - - if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; - else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; - else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; - else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; - else r->resample = stbi__resample_row_generic; +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { + unsigned int t = x * y + 128; + return (stbi_uc)((t + (t >> 8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, + int *comp, int req_comp) { + int n, decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && + (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // resample and color-convert + { + int k; + unsigned int i, j; + stbi_uc *output; + stbi_uc *coutput[4] = {NULL, NULL, NULL, NULL}; + + stbi__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); } - // can't error after this so, this is safe - output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); - if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } - - // now go ahead and resample - for (j=0; j < z->s->img_y; ++j) { - stbi_uc *out = output + n * z->s->img_x * j; - for (k=0; k < decode_n; ++k) { - stbi__resample *r = &res_comp[k]; - int y_bot = r->ystep >= (r->vs >> 1); - coutput[k] = r->resample(z->img_comp[k].linebuf, - y_bot ? r->line1 : r->line0, - y_bot ? r->line0 : r->line1, - r->w_lores, r->hs); - if (++r->ystep >= r->vs) { - r->ystep = 0; - r->line0 = r->line1; - if (++r->ypos < z->img_comp[k].y) - r->line1 += z->img_comp[k].w2; + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) + r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) + r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) + r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) + r->resample = z->resample_row_hv_2_kernel; + else + r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = + r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; } - } - if (n >= 3) { - stbi_uc *y = coutput[0]; - if (z->s->img_n == 3) { - if (is_rgb) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = y[i]; - out[1] = coutput[1][i]; - out[2] = coutput[2][i]; - out[3] = 255; - out += n; - } - } else { - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else if (z->s->img_n == 4) { - if (z->app14_color_transform == 0) { // CMYK - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(coutput[0][i], m); - out[1] = stbi__blinn_8x8(coutput[1][i], m); - out[2] = stbi__blinn_8x8(coutput[2][i], m); - out[3] = 255; - out += n; - } - } else if (z->app14_color_transform == 2) { // YCCK - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - out[0] = stbi__blinn_8x8(255 - out[0], m); - out[1] = stbi__blinn_8x8(255 - out[1], m); - out[2] = stbi__blinn_8x8(255 - out[2], m); - out += n; - } - } else { // YCbCr + alpha? Ignore the fourth channel for now - z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); - } - } else - for (i=0; i < z->s->img_x; ++i) { - out[0] = out[1] = out[2] = y[i]; - out[3] = 255; // not used if n==3 - out += n; - } - } else { - if (is_rgb) { - if (n == 1) - for (i=0; i < z->s->img_x; ++i) - *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - else { - for (i=0; i < z->s->img_x; ++i, out += 2) { - out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); - out[1] = 255; - } - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { - for (i=0; i < z->s->img_x; ++i) { - stbi_uc m = coutput[3][i]; - stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); - stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); - stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); - out[0] = stbi__compute_y(r, g, b); - out[1] = 255; - out += n; - } - } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { - for (i=0; i < z->s->img_x; ++i) { - out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); - out[1] = 255; - out += n; - } - } else { - stbi_uc *y = coutput[0]; - if (n == 1) - for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; - else - for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; } - } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, + n); + } + } else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = + stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { + *out++ = y[i]; + *out++ = 255; + } + } } - stbi__cleanup_jpeg(z); - *out_x = z->s->img_x; - *out_y = z->s->img_y; - if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output - return output; - } -} - -static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - unsigned char* result; - stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); - STBI_NOTUSED(ri); - j->s = s; - stbi__setup_jpeg(j); - result = load_jpeg_image(j, x,y,comp,req_comp); - STBI_FREE(j); - return result; -} - -static int stbi__jpeg_test(stbi__context *s) -{ - int r; - stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); - j->s = s; - stbi__setup_jpeg(j); - r = stbi__decode_jpeg_header(j, STBI__SCAN_type); - stbi__rewind(s); - STBI_FREE(j); - return r; -} - -static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) -{ - if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { - stbi__rewind( j->s ); - return 0; - } - if (x) *x = j->s->img_x; - if (y) *y = j->s->img_y; - if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; - return 1; -} - -static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) -{ - int result; - stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); - j->s = s; - result = stbi__jpeg_info_raw(j, x, y, comp); - STBI_FREE(j); - return result; + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) + *comp = + z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + unsigned char *result; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) { + int r; + stbi__jpeg *j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) { + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind(j->s); + return 0; + } + if (x) + *x = j->s->img_x; + if (y) + *y = j->s->img_y; + if (comp) + *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) { + int result; + stbi__jpeg *j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; } #endif @@ -3977,83 +4416,81 @@ static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_ZLIB // fast-way is faster to check than jpeg huffman, but slow way is slower -#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables -#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) // zlib-style huffman encoding // (jpegs packs from left, zlib from right, so can't share code) -typedef struct -{ - stbi__uint16 fast[1 << STBI__ZFAST_BITS]; - stbi__uint16 firstcode[16]; - int maxcode[17]; - stbi__uint16 firstsymbol[16]; - stbi_uc size[288]; - stbi__uint16 value[288]; +typedef struct { + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[288]; + stbi__uint16 value[288]; } stbi__zhuffman; -stbi_inline static int stbi__bitreverse16(int n) -{ - n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); - n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); - n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); - n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); +stbi_inline static int stbi__bitreverse16(int n) { + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); return n; } -stbi_inline static int stbi__bit_reverse(int v, int bits) -{ - STBI_ASSERT(bits <= 16); - // to bit reverse n bits, reverse 16 and shift - // e.g. 11 bits, bit reverse and shift away 5 - return stbi__bitreverse16(v) >> (16-bits); -} - -static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) -{ - int i,k=0; - int code, next_code[16], sizes[17]; - - // DEFLATE spec for generating codes - memset(sizes, 0, sizeof(sizes)); - memset(z->fast, 0, sizeof(z->fast)); - for (i=0; i < num; ++i) - ++sizes[sizelist[i]]; - sizes[0] = 0; - for (i=1; i < 16; ++i) - if (sizes[i] > (1 << i)) - return stbi__err("bad sizes", "Corrupt PNG"); - code = 0; - for (i=1; i < 16; ++i) { - next_code[i] = code; - z->firstcode[i] = (stbi__uint16) code; - z->firstsymbol[i] = (stbi__uint16) k; - code = (code + sizes[i]); - if (sizes[i]) - if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); - z->maxcode[i] = code << (16-i); // preshift for inner loop - code <<= 1; - k += sizes[i]; - } - z->maxcode[16] = 0x10000; // sentinel - for (i=0; i < num; ++i) { - int s = sizelist[i]; - if (s) { - int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; - stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); - z->size [c] = (stbi_uc ) s; - z->value[c] = (stbi__uint16) i; - if (s <= STBI__ZFAST_BITS) { - int j = stbi__bit_reverse(next_code[s],s); - while (j < (1 << STBI__ZFAST_BITS)) { - z->fast[j] = fastv; - j += (1 << s); - } - } - ++next_code[s]; +stbi_inline static int stbi__bit_reverse(int v, int bits) { + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16 - bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, + int num) { + int i, k = 0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i = 0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i = 1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i = 1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16)code; + z->firstsymbol[i] = (stbi__uint16)k; + code = (code + sizes[i]); + if (sizes[i]) + if (code - 1 >= (1 << i)) + return stbi__err("bad codelengths", "Corrupt PNG"); + z->maxcode[i] = code << (16 - i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i = 0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); + z->size[c] = (stbi_uc)s; + z->value[c] = (stbi__uint16)i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s], s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } } - } - return 1; + ++next_code[s]; + } + } + return 1; } // zlib-from-memory implementation for PNG reading @@ -4062,277 +4499,313 @@ static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int // we require PNG read all the IDATs and combine them into a single // memory buffer -typedef struct -{ - stbi_uc *zbuffer, *zbuffer_end; - int num_bits; - stbi__uint32 code_buffer; +typedef struct { + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; - char *zout; - char *zout_start; - char *zout_end; - int z_expandable; + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; - stbi__zhuffman z_length, z_distance; + stbi__zhuffman z_length, z_distance; } stbi__zbuf; -stbi_inline static int stbi__zeof(stbi__zbuf *z) -{ - return (z->zbuffer >= z->zbuffer_end); -} - -stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) -{ - return stbi__zeof(z) ? 0 : *z->zbuffer++; -} - -static void stbi__fill_bits(stbi__zbuf *z) -{ - do { - if (z->code_buffer >= (1U << z->num_bits)) { - z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ - return; - } - z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; - z->num_bits += 8; - } while (z->num_bits <= 24); +stbi_inline static int stbi__zeof(stbi__zbuf *z) { + return (z->zbuffer >= z->zbuffer_end); } -stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) -{ - unsigned int k; - if (z->num_bits < n) stbi__fill_bits(z); - k = z->code_buffer & ((1 << n) - 1); - z->code_buffer >>= n; - z->num_bits -= n; - return k; +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) { + return stbi__zeof(z) ? 0 : *z->zbuffer++; } -static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s,k; - // not resolved by fast table, so compute it the slow way - // use jpeg approach, which requires MSbits at top - k = stbi__bit_reverse(a->code_buffer, 16); - for (s=STBI__ZFAST_BITS+1; ; ++s) - if (k < z->maxcode[s]) - break; - if (s >= 16) return -1; // invalid code! - // code size is s, so: - b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; - if (b >= sizeof (z->size)) return -1; // some data was corrupt somewhere! - if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. - a->code_buffer >>= s; - a->num_bits -= s; - return z->value[b]; -} - -stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) -{ - int b,s; - if (a->num_bits < 16) { - if (stbi__zeof(a)) { - return -1; /* report error for unexpected end of data. */ - } - stbi__fill_bits(a); - } - b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; - if (b) { - s = b >> 9; - a->code_buffer >>= s; - a->num_bits -= s; - return b & 511; - } - return stbi__zhuffman_decode_slowpath(a, z); -} - -static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes -{ - char *q; - unsigned int cur, limit, old_limit; - z->zout = zout; - if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG"); - cur = (unsigned int) (z->zout - z->zout_start); - limit = old_limit = (unsigned) (z->zout_end - z->zout_start); - if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory"); - while (cur + n > limit) { - if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory"); - limit *= 2; - } - q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); - STBI_NOTUSED(old_limit); - if (q == NULL) return stbi__err("outofmem", "Out of memory"); - z->zout_start = q; - z->zout = q + cur; - z->zout_end = q + limit; - return 1; +static void stbi__fill_bits(stbi__zbuf *z) { + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) { + unsigned int k; + if (z->num_bits < n) + stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s, k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s = STBI__ZFAST_BITS + 1;; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) + return -1; // invalid code! + // code size is s, so: + b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= sizeof(z->size)) + return -1; // some data was corrupt somewhere! + if (z->size[b] != s) + return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z) { + int b, s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. */ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf *z, char *zout, + int n) // need to make room for n bytes +{ + char *q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) + return stbi__err("output buffer limit", "Corrupt PNG"); + cur = (unsigned int)(z->zout - z->zout_start); + limit = old_limit = (unsigned)(z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned)n) + return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if (limit > UINT_MAX / 2) + return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) + return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; } static const int stbi__zlength_base[31] = { - 3,4,5,6,7,8,9,10,11,13, - 15,17,19,23,27,31,35,43,51,59, - 67,83,99,115,131,163,195,227,258,0,0 }; - -static const int stbi__zlength_extra[31]= -{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 }; - -static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193, -257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0}; - -static const int stbi__zdist_extra[32] = -{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13}; - -static int stbi__parse_huffman_block(stbi__zbuf *a) -{ - char *zout = a->zout; - for(;;) { - int z = stbi__zhuffman_decode(a, &a->z_length); - if (z < 256) { - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes - if (zout >= a->zout_end) { - if (!stbi__zexpand(a, zout, 1)) return 0; - zout = a->zout; - } - *zout++ = (char) z; + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + +static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, + 4, 4, 5, 5, 5, 5, 0, 0, 0}; + +static const int stbi__zdist_base[32] = { + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; + +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +static int stbi__parse_huffman_block(stbi__zbuf *a) { + char *zout = a->zout; + for (;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) + return stbi__err("bad huffman code", + "Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) + return 0; + zout = a->zout; + } + *zout++ = (char)z; + } else { + stbi_uc *p; + int len, dist; + if (z == 256) { + a->zout = zout; + return 1; + } + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) + len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0) + return stbi__err("bad huffman code", "Corrupt PNG"); + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) + dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) + return stbi__err("bad dist", "Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) + return 0; + zout = a->zout; + } + p = (stbi_uc *)(zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { + do + *zout++ = v; + while (--len); + } } else { - stbi_uc *p; - int len,dist; - if (z == 256) { - a->zout = zout; - return 1; - } - z -= 257; - len = stbi__zlength_base[z]; - if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); - z = stbi__zhuffman_decode(a, &a->z_distance); - if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); - dist = stbi__zdist_base[z]; - if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); - if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); - if (zout + len > a->zout_end) { - if (!stbi__zexpand(a, zout, len)) return 0; - zout = a->zout; - } - p = (stbi_uc *) (zout - dist); - if (dist == 1) { // run of one byte; common in images. - stbi_uc v = *p; - if (len) { do *zout++ = v; while (--len); } - } else { - if (len) { do *zout++ = *p++; while (--len); } - } + if (len) { + do + *zout++ = *p++; + while (--len); + } } - } -} - -static int stbi__compute_huffman_codes(stbi__zbuf *a) -{ - static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; - stbi__zhuffman z_codelength; - stbi_uc lencodes[286+32+137];//padding for maximum single op - stbi_uc codelength_sizes[19]; - int i,n; - - int hlit = stbi__zreceive(a,5) + 257; - int hdist = stbi__zreceive(a,5) + 1; - int hclen = stbi__zreceive(a,4) + 4; - int ntot = hlit + hdist; - - memset(codelength_sizes, 0, sizeof(codelength_sizes)); - for (i=0; i < hclen; ++i) { - int s = stbi__zreceive(a,3); - codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; - } - if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; - - n = 0; - while (n < ntot) { - int c = stbi__zhuffman_decode(a, &z_codelength); - if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); - if (c < 16) - lencodes[n++] = (stbi_uc) c; - else { - stbi_uc fill = 0; - if (c == 16) { - c = stbi__zreceive(a,2)+3; - if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); - fill = lencodes[n-1]; - } else if (c == 17) { - c = stbi__zreceive(a,3)+3; - } else if (c == 18) { - c = stbi__zreceive(a,7)+11; - } else { - return stbi__err("bad codelengths", "Corrupt PNG"); - } - if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); - memset(lencodes+n, fill, c); - n += c; + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) { + static const stbi_uc length_dezigzag[19] = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc codelength_sizes[19]; + int i, n; + + int hlit = stbi__zreceive(a, 5) + 257; + int hdist = stbi__zreceive(a, 5) + 1; + int hclen = stbi__zreceive(a, 4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i = 0; i < hclen; ++i) { + int s = stbi__zreceive(a, 3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) + return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc)c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a, 2) + 3; + if (n == 0) + return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n - 1]; + } else if (c == 17) { + c = stbi__zreceive(a, 3) + 3; + } else if (c == 18) { + c = stbi__zreceive(a, 7) + 11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); } - } - if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); - if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; - return 1; -} - -static int stbi__parse_uncompressed_block(stbi__zbuf *a) -{ - stbi_uc header[4]; - int len,nlen,k; - if (a->num_bits & 7) - stbi__zreceive(a, a->num_bits & 7); // discard - // drain the bit-packed data into header - k = 0; - while (a->num_bits > 0) { - header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check - a->code_buffer >>= 8; - a->num_bits -= 8; - } - if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); - // now fill header the normal way - while (k < 4) - header[k++] = stbi__zget8(a); - len = header[1] * 256 + header[0]; - nlen = header[3] * 256 + header[2]; - if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); - if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); - if (a->zout + len > a->zout_end) - if (!stbi__zexpand(a, a->zout, len)) return 0; - memcpy(a->zout, a->zbuffer, len); - a->zbuffer += len; - a->zout += len; - return 1; -} - -static int stbi__parse_zlib_header(stbi__zbuf *a) -{ - int cmf = stbi__zget8(a); - int cm = cmf & 15; - /* int cinfo = cmf >> 4; */ - int flg = stbi__zget8(a); - if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec - if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png - if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png - // window = 1 << (8 + cinfo)... but who cares, we fully buffer output - return 1; -} - -static const stbi_uc stbi__zdefault_length[288] = -{ - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, - 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 -}; -static const stbi_uc stbi__zdefault_distance[32] = -{ - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 -}; + if (ntot - n < c) + return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes + n, fill, c); + n += c; + } + } + if (n != ntot) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) + return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) { + stbi_uc header[4]; + int len, nlen, k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = + (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) + return stbi__err("zlib corrupt", "Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) + return stbi__err("zlib corrupt", "Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) + return stbi__err("read past buffer", "Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) + return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) { + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if ((cmf * 256 + flg) % 31 != 0) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if (flg & 32) + return stbi__err("no preset dict", + "Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) + return stbi__err("bad compression", + "Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[288] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; +static const stbi_uc stbi__zdefault_distance[32] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; /* Init algorithm: { @@ -4346,117 +4819,131 @@ Init algorithm: } */ -static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) -{ - int final, type; - if (parse_header) - if (!stbi__parse_zlib_header(a)) return 0; - a->num_bits = 0; - a->code_buffer = 0; - do { - final = stbi__zreceive(a,1); - type = stbi__zreceive(a,2); - if (type == 0) { - if (!stbi__parse_uncompressed_block(a)) return 0; - } else if (type == 3) { - return 0; +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) { + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) + return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a, 1); + type = stbi__zreceive(a, 2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) + return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, 288)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) + return 0; } else { - if (type == 1) { - // use fixed code lengths - if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , 288)) return 0; - if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; - } else { - if (!stbi__compute_huffman_codes(a)) return 0; - } - if (!stbi__parse_huffman_block(a)) return 0; + if (!stbi__compute_huffman_codes(a)) + return 0; } - } while (!final); - return 1; -} - -static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) -{ - a->zout_start = obuf; - a->zout = obuf; - a->zout_end = obuf + olen; - a->z_expandable = exp; - - return stbi__parse_zlib(a, parse_header); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) -{ - return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); -} - -STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(initial_size); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer + len; - if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) - return (int) (a.zout - a.zout_start); - else - return -1; -} - -STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) -{ - stbi__zbuf a; - char *p = (char *) stbi__malloc(16384); - if (p == NULL) return NULL; - a.zbuffer = (stbi_uc *) buffer; - a.zbuffer_end = (stbi_uc *) buffer+len; - if (stbi__do_zlib(&a, p, 16384, 1, 0)) { - if (outlen) *outlen = (int) (a.zout - a.zout_start); - return a.zout_start; - } else { - STBI_FREE(a.zout_start); - return NULL; - } -} - -STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) -{ - stbi__zbuf a; - a.zbuffer = (stbi_uc *) ibuffer; - a.zbuffer_end = (stbi_uc *) ibuffer + ilen; - if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) - return (int) (a.zout - a.zout_start); - else - return -1; + if (!stbi__parse_huffman_block(a)) + return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, + int parse_header) { + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, + int initial_size, int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, + int *outlen) { + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, + int len, + int initial_size, + int *outlen, + int parse_header) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, + char const *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int)(a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, + int *outlen) { + stbi__zbuf a; + char *p = (char *)stbi__malloc(16384); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, + const char *ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int)(a.zout - a.zout_start); + else + return -1; } #endif @@ -4471,1083 +4958,1312 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char // - uses stb_zlib, a PD zlib implementation with fast huffman decoding #ifndef STBI_NO_PNG -typedef struct -{ - stbi__uint32 length; - stbi__uint32 type; +typedef struct { + stbi__uint32 length; + stbi__uint32 type; } stbi__pngchunk; -static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) -{ - stbi__pngchunk c; - c.length = stbi__get32be(s); - c.type = stbi__get32be(s); - return c; +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) { + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; } -static int stbi__check_png_header(stbi__context *s) -{ - static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; - int i; - for (i=0; i < 8; ++i) - if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); - return 1; +static int stbi__check_png_header(stbi__context *s) { + static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + int i; + for (i = 0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) + return stbi__err("bad png sig", "Not a PNG"); + return 1; } -typedef struct -{ - stbi__context *s; - stbi_uc *idata, *expanded, *out; - int depth; +typedef struct { + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; } stbi__png; - enum { - STBI__F_none=0, - STBI__F_sub=1, - STBI__F_up=2, - STBI__F_avg=3, - STBI__F_paeth=4, - // synthetic filters used for first scanline to avoid needing a dummy row of 0s - STBI__F_avg_first, - STBI__F_paeth_first + STBI__F_none = 0, + STBI__F_sub = 1, + STBI__F_up = 2, + STBI__F_avg = 3, + STBI__F_paeth = 4, + // synthetic filters used for first scanline to avoid needing a dummy row of + // 0s + STBI__F_avg_first, + STBI__F_paeth_first }; -static stbi_uc first_row_filter[5] = -{ - STBI__F_none, - STBI__F_sub, - STBI__F_none, - STBI__F_avg_first, - STBI__F_paeth_first -}; +static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, + STBI__F_avg_first, STBI__F_paeth_first}; -static int stbi__paeth(int a, int b, int c) -{ - int p = a + b - c; - int pa = abs(p-a); - int pb = abs(p-b); - int pc = abs(p-c); - if (pa <= pb && pa <= pc) return a; - if (pb <= pc) return b; - return c; +static int stbi__paeth(int a, int b, int c) { + int p = a + b - c; + int pa = abs(p - a); + int pb = abs(p - b); + int pc = abs(p - c); + if (pa <= pb && pa <= pc) + return a; + if (pb <= pc) + return b; + return c; } -static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; +static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, + 0, 0, 0, 0x01}; // create the png data from post-deflated data -static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) -{ - int bytes = (depth == 16? 2 : 1); - stbi__context *s = a->s; - stbi__uint32 i,j,stride = x*out_n*bytes; - stbi__uint32 img_len, img_width_bytes; - int k; - int img_n = s->img_n; // copy it into a local for later - - int output_bytes = out_n*bytes; - int filter_bytes = img_n*bytes; - int width = x; - - STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); - a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into - if (!a->out) return stbi__err("outofmem", "Out of memory"); - - if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); - img_width_bytes = (((img_n * x * depth) + 7) >> 3); - img_len = (img_width_bytes + 1) * y; - - // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, - // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), - // so just check for raw_len < img_len always. - if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); - - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *prior; - int filter = *raw++; - - if (filter > 4) - return stbi__err("invalid filter","Corrupt PNG"); - - if (depth < 8) { - if (img_width_bytes > x) return stbi__err("invalid width","Corrupt PNG"); - cur += x*out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - - // if first row, use special filter that doesn't sample previous row - if (j == 0) filter = first_row_filter[filter]; - - // handle first byte explicitly - for (k=0; k < filter_bytes; ++k) { - switch (filter) { - case STBI__F_none : cur[k] = raw[k]; break; - case STBI__F_sub : cur[k] = raw[k]; break; - case STBI__F_up : cur[k] = STBI__BYTECAST(raw[k] + prior[k]); break; - case STBI__F_avg : cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); break; - case STBI__F_paeth : cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0,prior[k],0)); break; - case STBI__F_avg_first : cur[k] = raw[k]; break; - case STBI__F_paeth_first: cur[k] = raw[k]; break; - } +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, + stbi__uint32 raw_len, int out_n, + stbi__uint32 x, stbi__uint32 y, int depth, + int color) { + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i, j, stride = x * out_n * bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n * bytes; + int filter_bytes = img_n * bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); + a->out = (stbi_uc *)stbi__malloc_mad3( + x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) + return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) + return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on + // non-interlaced PNGs, but issue #276 reported a PNG in the wild that had + // extra data at the end (all zeros), so just check for raw_len < img_len + // always. + if (raw_len < img_len) + return stbi__err("not enough pixels", "Corrupt PNG"); + + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter", "Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) + return stbi__err("invalid width", "Corrupt PNG"); + cur += + x * out_n - img_width_bytes; // store output to the rightmost img_len + // bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = + cur - + stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) + filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k = 0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none: + cur[k] = raw[k]; + break; + case STBI__F_sub: + cur[k] = raw[k]; + break; + case STBI__F_up: + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); + break; + case STBI__F_paeth: + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); + break; + case STBI__F_avg_first: + cur[k] = raw[k]; + break; + case STBI__F_paeth_first: + cur[k] = raw[k]; + break; } + } - if (depth == 8) { - if (img_n != out_n) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += out_n; - prior += out_n; - } else if (depth == 16) { - if (img_n != out_n) { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes+1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } else { - raw += 1; - cur += 1; - prior += 1; + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } - // this is a little gross, so that we don't switch per-pixel or per-component - if (depth < 8 || img_n == out_n) { - int nk = (width - 1)*filter_bytes; - #define STBI__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - // "none" filter turns into a memcpy here; make that explicit. - case STBI__F_none: memcpy(cur, raw, nk); break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],prior[k],prior[k-filter_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes],0,0)); } break; - } - #undef STBI__CASE - raw += nk; - } else { - STBI_ASSERT(img_n+1 == out_n); - #define STBI__CASE(f) \ - case f: \ - for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } break; - STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k- output_bytes]); } break; - STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } break; - STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k- output_bytes])>>1)); } break; - STBI__CASE(STBI__F_paeth) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],prior[k],prior[k- output_bytes])); } break; - STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k- output_bytes] >> 1)); } break; - STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k- output_bytes],0,0)); } break; - } - #undef STBI__CASE - - // the loop above sets the high byte of the pixels' alpha, but for - // 16 bit png files we also need the low byte set. we'll do that here. - if (depth == 16) { - cur = a->out + stride*j; // start at the beginning of the row again - for (i=0; i < x; ++i,cur+=output_bytes) { - cur[filter_bytes+1] = 255; - } - } + // this is a little gross, so that we don't switch per-pixel or + // per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1) * filter_bytes; +#define STBI__CASE(f) \ + case f: \ + for (k = 0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. + case STBI__F_none: + memcpy(cur, raw, nk); + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - filter_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - filter_bytes], prior[k], + prior[k - filter_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); + } + break; } - } - - // we make a separate pass to expand bits to pixels; for performance, - // this could run two scanlines behind the above code, so it won't - // intefere with filtering but will still be in the cache. - if (depth < 8) { - for (j=0; j < y; ++j) { - stbi_uc *cur = a->out + stride*j; - stbi_uc *in = a->out + stride*j + x*out_n - img_width_bytes; - // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for 1/2/4-bit - // png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that will be skipped in the later loop - stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range - - // note that the final byte might overshoot and write more data than desired. - // we can allocate enough data that this never writes out of memory, but it - // could also overwrite the next scanline. can it overwrite non-empty data - // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. - // so we need to explicitly clamp the final ones - - if (depth == 4) { - for (k=x*img_n; k >= 2; k-=2, ++in) { - *cur++ = scale * ((*in >> 4) ); - *cur++ = scale * ((*in ) & 0x0f); - } - if (k > 0) *cur++ = scale * ((*in >> 4) ); - } else if (depth == 2) { - for (k=x*img_n; k >= 4; k-=4, ++in) { - *cur++ = scale * ((*in >> 6) ); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in ) & 0x03); - } - if (k > 0) *cur++ = scale * ((*in >> 6) ); - if (k > 1) *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) *cur++ = scale * ((*in >> 2) & 0x03); - } else if (depth == 1) { - for (k=x*img_n; k >= 8; k-=8, ++in) { - *cur++ = scale * ((*in >> 7) ); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in ) & 0x01); - } - if (k > 0) *cur++ = scale * ((*in >> 7) ); - if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != out_n) { - int q; - // insert alpha = 255 - cur = a->out + stride*j; - if (img_n == 1) { - for (q=x-1; q >= 0; --q) { - cur[q*2+1] = 255; - cur[q*2+0] = cur[q]; - } - } else { - STBI_ASSERT(img_n == 3); - for (q=x-1; q >= 0; --q) { - cur[q*4+3] = 255; - cur[q*4+2] = cur[q*3+2]; - cur[q*4+1] = cur[q*3+1]; - cur[q*4+0] = cur[q*3+0]; - } - } - } +#undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n + 1 == out_n); +#define STBI__CASE(f) \ + case f: \ + for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, \ + cur += output_bytes, prior += output_bytes) \ + for (k = 0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { + cur[k] = raw[k]; + } + break; + STBI__CASE(STBI__F_sub) { + cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); + } + break; + STBI__CASE(STBI__F_up) { + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + } + break; + STBI__CASE(STBI__F_avg) { + cur[k] = STBI__BYTECAST(raw[k] + + ((prior[k] + cur[k - output_bytes]) >> 1)); + } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + + stbi__paeth(cur[k - output_bytes], prior[k], + prior[k - output_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { + cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); + } + break; + STBI__CASE(STBI__F_paeth_first) { + cur[k] = + STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); + } + break; } - } else if (depth == 16) { - // force the image data from big-endian to platform-native. - // this is done in a separate pass due to the decoding relying - // on the data being untouched, but could probably be done - // per-line during decode if care is taken. - stbi_uc *cur = a->out; - stbi__uint16 *cur16 = (stbi__uint16*)cur; - - for(i=0; i < x*y*out_n; ++i,cur16++,cur+=2) { - *cur16 = (cur[0] << 8) | cur[1]; +#undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride * j; // start at the beginning of the row again + for (i = 0; i < x; ++i, cur += output_bytes) { + cur[filter_bytes + 1] = 255; + } } - } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j = 0; j < y; ++j) { + stbi_uc *cur = a->out + stride * j; + stbi_uc *in = a->out + stride * j + x * out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common + // 8-bit path optimal at minimal cost for 1/2/4-bit png guarante byte + // alignment, if width is not multiple of 8/4/2 we'll decode dummy + // trailing data that will be skipped in the later loop + stbi_uc scale = (color == 0) + ? stbi__depth_scale_table[depth] + : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than + // desired. we can allocate enough data that this never writes out of + // memory, but it could also overwrite the next scanline. can it overwrite + // non-empty data on the next scanline? yes, consider 1-pixel-wide + // scanlines with 1-bit-per-pixel. so we need to explicitly clamp the + // final ones + + if (depth == 4) { + for (k = x * img_n; k >= 2; k -= 2, ++in) { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } else if (depth == 2) { + for (k = x * img_n; k >= 4; k -= 4, ++in) { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k = x * img_n; k >= 8; k -= 8, ++in) { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) + *cur++ = scale * ((*in >> 7)); + if (k > 1) + *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) + *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) + *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) + *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) + *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) + *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride * j; + if (img_n == 1) { + for (q = x - 1; q >= 0; --q) { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q = x - 1; q >= 0; --q) { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } + } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc *cur = a->out; + stbi__uint16 *cur16 = (stbi__uint16 *)cur; + + for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, + stbi__uint32 image_data_len, int out_n, + int depth, int color, int interlaced) { + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, + a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + for (p = 0; p < 7; ++p) { + int xorig[] = {0, 4, 0, 2, 0, 1, 0}; + int yorig[] = {0, 0, 4, 0, 2, 0, 1}; + int xspc[] = {8, 8, 4, 4, 2, 2, 1}; + int yspc[] = {8, 8, 8, 4, 4, 2, 2}; + int i, j, x, y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, + y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j = 0; j < y; ++j) { + for (i = 0; i < x; ++i) { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, + a->out + (j * x + i) * out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; - return 1; + return 1; } -static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) -{ - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; - stbi_uc *final; - int p; - if (!interlaced) - return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); - - // de-interlacing - final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); - for (p=0; p < 7; ++p) { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; - int i,j,x,y; - // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 - x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; - y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; - if (x && y) { - stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; - if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { - STBI_FREE(final); - return 0; - } - for (j=0; j < y; ++j) { - for (i=0; i < x; ++i) { - int out_y = j*yspc[p]+yorig[p]; - int out_x = i*xspc[p]+xorig[p]; - memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, - a->out + (j*x+i)*out_bytes, out_bytes); - } - } - STBI_FREE(a->out); - image_data += img_len; - image_data_len -= img_len; - } - } - a->out = final; +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; - return 1; -} + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); -static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - // compute color-based transparency, assuming we've - // already got 255 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); - - if (out_n == 2) { - for (i=0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 255); - p += 2; - } - } else { - for (i=0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi__uint16 *p = (stbi__uint16*) z->out; +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], + int out_n) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16 *)z->out; - // compute color-based transparency, assuming we've - // already got 65535 as the alpha value in the output - STBI_ASSERT(out_n == 2 || out_n == 4); + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); - if (out_n == 2) { - for (i = 0; i < pixel_count; ++i) { - p[1] = (p[0] == tc[0] ? 0 : 65535); - p += 2; - } - } else { - for (i = 0; i < pixel_count; ++i) { - if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) - p[3] = 0; - p += 4; - } - } - return 1; + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; } -static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) -{ - stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; - stbi_uc *p, *temp_out, *orig = a->out; - - p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); - if (p == NULL) return stbi__err("outofmem", "Out of memory"); - - // between here and free(out) below, exitting would leak - temp_out = p; - - if (pal_img_n == 3) { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p += 3; - } - } else { - for (i=0; i < pixel_count; ++i) { - int n = orig[i]*4; - p[0] = palette[n ]; - p[1] = palette[n+1]; - p[2] = palette[n+2]; - p[3] = palette[n+3]; - p += 4; - } - } - STBI_FREE(a->out); - a->out = temp_out; +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, + int pal_img_n) { + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); - STBI_NOTUSED(len); + // between here and free(out) below, exitting would leak + temp_out = p; - return 1; + if (pal_img_n == 3) { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p += 3; + } + } else { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p[3] = palette[n + 3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; } static int stbi__unpremultiply_on_load = 0; static int stbi__de_iphone_flag = 0; -STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) -{ - stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; +STBIDEF void +stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load = flag_true_if_should_unpremultiply; } -STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) -{ - stbi__de_iphone_flag = flag_true_if_should_convert; +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { + stbi__de_iphone_flag = flag_true_if_should_convert; } -static void stbi__de_iphone(stbi__png *z) -{ - stbi__context *s = z->s; - stbi__uint32 i, pixel_count = s->img_x * s->img_y; - stbi_uc *p = z->out; - - if (s->img_out_n == 3) { // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 3; +static void stbi__de_iphone(stbi__png *z) { + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i = 0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = (t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; } - } else { - STBI_ASSERT(s->img_out_n == 4); - if (stbi__unpremultiply_on_load) { - // convert bgr to rgb and unpremultiply - for (i=0; i < pixel_count; ++i) { - stbi_uc a = p[3]; - stbi_uc t = p[0]; - if (a) { - stbi_uc half = a / 2; - p[0] = (p[2] * 255 + half) / a; - p[1] = (p[1] * 255 + half) / a; - p[2] = ( t * 255 + half) / a; - } else { - p[0] = p[2]; - p[2] = t; - } - p += 4; - } + } else { + // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a, b, c, d) \ + (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + \ + (unsigned)(d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) { + stbi_uc palette[1024], pal_img_n = 0; + stbi_uc has_trans = 0, tc[3] = {0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; + int first = 1, k, interlace = 0, color = 0, is_iphone = 0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) + return 0; + + if (scan == STBI__SCAN_type) + return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { + int comp, filter; + if (!first) + return stbi__err("multiple IHDR", "Corrupt PNG"); + first = 0; + if (c.length != 13) + return stbi__err("bad IHDR len", "Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + z->depth = stbi__get8(s); + if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && + z->depth != 16) + return stbi__err("1/2/4/8/16-bit only", + "PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); + if (color > 6) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3 && z->depth == 16) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3) + pal_img_n = 3; + else if (color & 1) + return stbi__err("bad ctype", "Corrupt PNG"); + comp = stbi__get8(s); + if (comp) + return stbi__err("bad comp method", "Corrupt PNG"); + filter = stbi__get8(s); + if (filter) + return stbi__err("bad filter method", "Corrupt PNG"); + interlace = stbi__get8(s); + if (interlace > 1) + return stbi__err("bad interlace method", "Corrupt PNG"); + if (!s->img_x || !s->img_y) + return stbi__err("0-pixel image", "Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) + return stbi__err("too large", "Image too large to decode"); + if (scan == STBI__SCAN_header) + return 1; } else { - // convert bgr to rgb - for (i=0; i < pixel_count; ++i) { - stbi_uc t = p[0]; - p[0] = p[2]; - p[2] = t; - p += 4; - } + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. + s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) + return stbi__err("too large", "Corrupt PNG"); + // if SCAN_header, have to scan to see if we have a tRNS } - } -} - -#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + break; + } -static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) -{ - stbi_uc palette[1024], pal_img_n=0; - stbi_uc has_trans=0, tc[3]={0}; - stbi__uint16 tc16[3]; - stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; - int first=1,k,interlace=0, color=0, is_iphone=0; - stbi__context *s = z->s; - - z->expanded = NULL; - z->idata = NULL; - z->out = NULL; - - if (!stbi__check_png_header(s)) return 0; - - if (scan == STBI__SCAN_type) return 1; - - for (;;) { - stbi__pngchunk c = stbi__get_chunk_header(s); - switch (c.type) { - case STBI__PNG_TYPE('C','g','B','I'): - is_iphone = 1; - stbi__skip(s, c.length); - break; - case STBI__PNG_TYPE('I','H','D','R'): { - int comp,filter; - if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); - first = 0; - if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); - s->img_x = stbi__get32be(s); - s->img_y = stbi__get32be(s); - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); - color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); - if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); - comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); - filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); - interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); - if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); - if (!pal_img_n) { - s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); - if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); - if (scan == STBI__SCAN_header) return 1; - } else { - // if paletted, then pal_n is our final components, and - // img_n is # components to decompress/filter. - s->img_n = 1; - if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); - // if SCAN_header, have to scan to see if we have a tRNS - } - break; - } - - case STBI__PNG_TYPE('P','L','T','E'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); - pal_len = c.length / 3; - if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); - for (i=0; i < pal_len; ++i) { - palette[i*4+0] = stbi__get8(s); - palette[i*4+1] = stbi__get8(s); - palette[i*4+2] = stbi__get8(s); - palette[i*4+3] = 255; - } - break; - } - - case STBI__PNG_TYPE('t','R','N','S'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); - if (pal_img_n) { - if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } - if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); - if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); - pal_img_n = 4; - for (i=0; i < c.length; ++i) - palette[i*4+3] = stbi__get8(s); - } else { - if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); - if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); - has_trans = 1; - if (z->depth == 16) { - for (k = 0; k < s->img_n; ++k) tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is - } else { - for (k = 0; k < s->img_n; ++k) tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger - } - } - break; - } - - case STBI__PNG_TYPE('I','D','A','T'): { - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); - if (scan == STBI__SCAN_header) { s->img_n = pal_img_n; return 1; } - if ((int)(ioff + c.length) < (int)ioff) return 0; - if (ioff + c.length > idata_limit) { - stbi__uint32 idata_limit_old = idata_limit; - stbi_uc *p; - if (idata_limit == 0) idata_limit = c.length > 4096 ? c.length : 4096; - while (ioff + c.length > idata_limit) - idata_limit *= 2; - STBI_NOTUSED(idata_limit_old); - p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); - z->idata = p; - } - if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); - ioff += c.length; - break; - } - - case STBI__PNG_TYPE('I','E','N','D'): { - stbi__uint32 raw_len, bpl; - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if (scan != STBI__SCAN_load) return 1; - if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); - // initial guess for decoded data size to avoid unnecessary reallocs - bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component - raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; - z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); - if (z->expanded == NULL) return 0; // zlib should set error - STBI_FREE(z->idata); z->idata = NULL; - if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) - s->img_out_n = s->img_n+1; - else - s->img_out_n = s->img_n; - if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; - if (has_trans) { - if (z->depth == 16) { - if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; - } else { - if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; - } - } - if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) - stbi__de_iphone(z); - if (pal_img_n) { - // pal_img_n == 3 or 4 - s->img_n = pal_img_n; // record the actual colors we had - s->img_out_n = pal_img_n; - if (req_comp >= 3) s->img_out_n = req_comp; - if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) - return 0; - } else if (has_trans) { - // non-paletted image with tRNS -> source image has (constant) alpha - ++s->img_n; - } - STBI_FREE(z->expanded); z->expanded = NULL; - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - return 1; - } - - default: - // if critical, fail - if (first) return stbi__err("first not IHDR", "Corrupt PNG"); - if ((c.type & (1 << 29)) == 0) { - #ifndef STBI_NO_FAILURE_STRINGS - // not threadsafe - static char invalid_chunk[] = "XXXX PNG chunk not known"; - invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); - invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); - invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); - invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); - #endif - return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); - } - stbi__skip(s, c.length); - break; + case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256 * 3) + return stbi__err("invalid PLTE", "Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) + return stbi__err("invalid PLTE", "Corrupt PNG"); + for (i = 0; i < pal_len; ++i) { + palette[i * 4 + 0] = stbi__get8(s); + palette[i * 4 + 1] = stbi__get8(s); + palette[i * 4 + 2] = stbi__get8(s); + palette[i * 4 + 3] = 255; } - // end of PNG chunk, read and skip CRC - stbi__get32be(s); - } -} + break; + } -static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) -{ - void *result=NULL; - if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); - if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { - if (p->depth <= 8) - ri->bits_per_channel = 8; - else if (p->depth == 16) - ri->bits_per_channel = 16; - else - return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); - result = p->out; - p->out = NULL; - if (req_comp && req_comp != p->s->img_out_n) { - if (ri->bits_per_channel == 8) - result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - else - result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); - p->s->img_out_n = req_comp; - if (result == NULL) return result; + case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) + return stbi__err("tRNS after IDAT", "Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { + s->img_n = 4; + return 1; + } + if (pal_len == 0) + return stbi__err("tRNS before PLTE", "Corrupt PNG"); + if (c.length > pal_len) + return stbi__err("bad tRNS len", "Corrupt PNG"); + pal_img_n = 4; + for (i = 0; i < c.length; ++i) + palette[i * 4 + 3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) + return stbi__err("tRNS with alpha", "Corrupt PNG"); + if (c.length != (stbi__uint32)s->img_n * 2) + return stbi__err("bad tRNS len", "Corrupt PNG"); + has_trans = 1; + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * + stbi__depth_scale_table[z->depth]; // non 8-bit images will + // be larger + } } - *x = p->s->img_x; - *y = p->s->img_y; - if (n) *n = p->s->img_n; - } - STBI_FREE(p->out); p->out = NULL; - STBI_FREE(p->expanded); p->expanded = NULL; - STBI_FREE(p->idata); p->idata = NULL; - - return result; -} - -static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi__png p; - p.s = s; - return stbi__do_png(&p, x,y,comp,req_comp, ri); -} - -static int stbi__png_test(stbi__context *s) -{ - int r; - r = stbi__check_png_header(s); - stbi__rewind(s); - return r; -} + break; + } -static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) -{ - if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { - stbi__rewind( p->s ); - return 0; - } - if (x) *x = p->s->img_x; - if (y) *y = p->s->img_y; - if (comp) *comp = p->s->img_n; - return 1; -} + case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) + return stbi__err("no PLTE", "Corrupt PNG"); + if (scan == STBI__SCAN_header) { + s->img_n = pal_img_n; + return 1; + } + if ((int)(ioff + c.length) < (int)ioff) + return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) + idata_limit = c.length > 4096 ? c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, + idata_limit); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata + ioff, c.length)) + return stbi__err("outofdata", "Corrupt PNG"); + ioff += c.length; + break; + } -static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__png p; - p.s = s; - return stbi__png_info_raw(&p, x, y, comp); -} + case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + stbi__uint32 raw_len, bpl; + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) + return 1; + if (z->idata == NULL) + return stbi__err("no IDAT", "Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag( + (char *)z->idata, ioff, raw_len, (int *)&raw_len, !is_iphone); + if (z->expanded == NULL) + return 0; // zlib should set error + STBI_FREE(z->idata); + z->idata = NULL; + if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || + has_trans) + s->img_out_n = s->img_n + 1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, + z->depth, color, interlace)) + return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) + return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) + return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) + s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); + z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } -static int stbi__png_is16(stbi__context *s) -{ - stbi__png p; - p.s = s; - if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) - return 0; - if (p.depth != 16) { - stbi__rewind(p.s); - return 0; - } - return 1; + default: + // if critical, fail + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { +#ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); +#endif + return stbi__err(invalid_chunk, + "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, + stbi__result_info *ri) { + void *result = NULL; + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", + "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, + req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) + return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) + *n = p->s->img_n; + } + STBI_FREE(p->out); + p->out = NULL; + STBI_FREE(p->expanded); + p->expanded = NULL; + STBI_FREE(p->idata); + p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi__png p; + p.s = s; + return stbi__do_png(&p, x, y, comp, req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) { + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) { + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind(p->s); + return 0; + } + if (x) + *x = p->s->img_x; + if (y) + *y = p->s->img_y; + if (comp) + *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) { + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) { + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; } #endif // Microsoft/Windows BMP image #ifndef STBI_NO_BMP -static int stbi__bmp_test_raw(stbi__context *s) -{ - int r; - int sz; - if (stbi__get8(s) != 'B') return 0; - if (stbi__get8(s) != 'M') return 0; - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - stbi__get32le(s); // discard data offset - sz = stbi__get32le(s); - r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); - return r; -} - -static int stbi__bmp_test(stbi__context *s) -{ - int r = stbi__bmp_test_raw(s); - stbi__rewind(s); - return r; +static int stbi__bmp_test_raw(stbi__context *s) { + int r; + int sz; + if (stbi__get8(s) != 'B') + return 0; + if (stbi__get8(s) != 'M') + return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) { + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; } - // returns 0..31 for the highest set bit -static int stbi__high_bit(unsigned int z) -{ - int n=0; - if (z == 0) return -1; - if (z >= 0x10000) { n += 16; z >>= 16; } - if (z >= 0x00100) { n += 8; z >>= 8; } - if (z >= 0x00010) { n += 4; z >>= 4; } - if (z >= 0x00004) { n += 2; z >>= 2; } - if (z >= 0x00002) { n += 1;/* >>= 1;*/ } - return n; +static int stbi__high_bit(unsigned int z) { + int n = 0; + if (z == 0) + return -1; + if (z >= 0x10000) { + n += 16; + z >>= 16; + } + if (z >= 0x00100) { + n += 8; + z >>= 8; + } + if (z >= 0x00010) { + n += 4; + z >>= 4; + } + if (z >= 0x00004) { + n += 2; + z >>= 2; + } + if (z >= 0x00002) { + n += 1; /* >>= 1;*/ + } + return n; } -static int stbi__bitcount(unsigned int a) -{ - a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 - a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 - a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits - a = (a + (a >> 8)); // max 16 per 8 bits - a = (a + (a >> 16)); // max 32 per 8 bits - return a & 0xff; +static int stbi__bitcount(unsigned int a) { + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; } // extract an arbitrarily-aligned N-bit value (N=bits) // from v, and then make it 8-bits long and fractionally // extend it to full full range. -static int stbi__shiftsigned(unsigned int v, int shift, int bits) -{ - static unsigned int mul_table[9] = { +static int stbi__shiftsigned(unsigned int v, int shift, int bits) { + static unsigned int mul_table[9] = { 0, - 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, - 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, - }; - static unsigned int shift_table[9] = { - 0, 0,0,1,0,2,4,6,0, - }; - if (shift < 0) - v <<= -shift; - else - v >>= shift; - STBI_ASSERT(v < 256); - v >>= (8-bits); - STBI_ASSERT(bits >= 0 && bits <= 8); - return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; -} - -typedef struct -{ - int bpp, offset, hsz; - unsigned int mr,mg,mb,ma, all_a; - int extra_read; + 0xff /*0b11111111*/, + 0x55 /*0b01010101*/, + 0x49 /*0b01001001*/, + 0x11 /*0b00010001*/, + 0x21 /*0b00100001*/, + 0x41 /*0b01000001*/, + 0x81 /*0b10000001*/, + 0x01 /*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0, 0, 1, 0, 2, 4, 6, 0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8 - bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct { + int bpp, offset, hsz; + unsigned int mr, mg, mb, ma, all_a; + int extra_read; } stbi__bmp_data; -static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) -{ - int hsz; - if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); - stbi__get32le(s); // discard filesize - stbi__get16le(s); // discard reserved - stbi__get16le(s); // discard reserved - info->offset = stbi__get32le(s); - info->hsz = hsz = stbi__get32le(s); - info->mr = info->mg = info->mb = info->ma = 0; - info->extra_read = 14; - - if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); - - if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); - if (hsz == 12) { - s->img_x = stbi__get16le(s); - s->img_y = stbi__get16le(s); - } else { - s->img_x = stbi__get32le(s); - s->img_y = stbi__get32le(s); - } - if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); - info->bpp = stbi__get16le(s); - if (hsz != 12) { - int compress = stbi__get32le(s); - if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); - stbi__get32le(s); // discard sizeof - stbi__get32le(s); // discard hres - stbi__get32le(s); // discard vres - stbi__get32le(s); // discard colorsused - stbi__get32le(s); // discard max important - if (hsz == 40 || hsz == 56) { - if (hsz == 56) { - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - stbi__get32le(s); - } - if (info->bpp == 16 || info->bpp == 32) { - if (compress == 0) { - if (info->bpp == 32) { - info->mr = 0xffu << 16; - info->mg = 0xffu << 8; - info->mb = 0xffu << 0; - info->ma = 0xffu << 24; - info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 - } else { - info->mr = 31u << 10; - info->mg = 31u << 5; - info->mb = 31u << 0; - } - } else if (compress == 3) { - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->extra_read += 12; - // not documented, but generated by photoshop and handled by mspaint - if (info->mr == info->mg && info->mg == info->mb) { - // ?!?!? - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else - return stbi__errpuc("bad BMP", "bad BMP"); - } - } else { - int i; - if (hsz != 108 && hsz != 124) +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) { + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') + return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) + return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) + return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) + return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) + return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha + // channel but it was all 0 + } else { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? return stbi__errpuc("bad BMP", "bad BMP"); - info->mr = stbi__get32le(s); - info->mg = stbi__get32le(s); - info->mb = stbi__get32le(s); - info->ma = stbi__get32le(s); - stbi__get32le(s); // discard color space - for (i=0; i < 12; ++i) - stbi__get32le(s); // discard color space parameters - if (hsz == 124) { - stbi__get32le(s); // discard rendering intent - stbi__get32le(s); // discard offset of profile data - stbi__get32le(s); // discard size of profile data - stbi__get32le(s); // discard reserved - } + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); } - } - return (void *) 1; -} - - -static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - unsigned int mr=0,mg=0,mb=0,ma=0, all_a; - stbi_uc pal[256][4]; - int psize=0,i,j,width; - int flip_vertically, pad, target; - stbi__bmp_data info; - STBI_NOTUSED(ri); - - info.all_a = 255; - if (stbi__bmp_parse_header(s, &info) == NULL) - return NULL; // error code already set - - flip_vertically = ((int) s->img_y) > 0; - s->img_y = abs((int) s->img_y); - - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - mr = info.mr; - mg = info.mg; - mb = info.mb; - ma = info.ma; - all_a = info.all_a; - - if (info.hsz == 12) { - if (info.bpp < 24) - psize = (info.offset - info.extra_read - 24) / 3; - } else { - if (info.bpp < 16) - psize = (info.offset - info.extra_read - info.hsz) >> 2; - } - if (psize == 0) { - STBI_ASSERT(info.offset == s->callback_already_read + (int) (s->img_buffer - s->img_buffer_original)); - if (info.offset != s->callback_already_read + (s->img_buffer - s->buffer_start)) { - return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + stbi__get32le(s); // discard color space + for (i = 0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved } - } - - if (info.bpp == 24 && ma == 0xff000000) - s->img_n = 3; - else - s->img_n = ma ? 4 : 3; - if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 - target = req_comp; - else - target = s->img_n; // if they want monochrome, we'll post-convert - - // sanity-check size - if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "Corrupt BMP"); - - out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - if (info.bpp < 16) { - int z=0; - if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } - for (i=0; i < psize; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - if (info.hsz != 12) stbi__get8(s); - pal[i][3] = 255; + } + } + return (void *)1; +} + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; + stbi_uc pal[256][4]; + int psize = 0, i, j, width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int)s->img_y) > 0; + s->img_y = abs((int)s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + STBI_ASSERT(info.offset == + s->callback_already_read + + (int)(s->img_buffer - s->img_buffer_original)); + if (info.offset != + s->callback_already_read + (s->img_buffer - s->buffer_start)) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z = 0; + if (psize == 0 || psize > 256) { + STBI_FREE(out); + return stbi__errpuc("invalid", "Corrupt BMP"); + } + for (i = 0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) + stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - + psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) + width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) + width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) + width = s->img_x; + else { + STBI_FREE(out); + return stbi__errpuc("bad bpp", "Corrupt BMP"); + } + pad = (-width) & 3; + if (info.bpp == 1) { + for (j = 0; j < (int)s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i = 0; i < (int)s->img_x; ++i) { + int color = (v >> bit_offset) & 0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + if ((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); } - stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); - if (info.bpp == 1) width = (s->img_x + 7) >> 3; - else if (info.bpp == 4) width = (s->img_x + 1) >> 1; - else if (info.bpp == 8) width = s->img_x; - else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } - pad = (-width)&3; - if (info.bpp == 1) { - for (j=0; j < (int) s->img_y; ++j) { - int bit_offset = 7, v = stbi__get8(s); - for (i=0; i < (int) s->img_x; ++i) { - int color = (v>>bit_offset)&0x1; - out[z++] = pal[color][0]; - out[z++] = pal[color][1]; - out[z++] = pal[color][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - if((--bit_offset) < 0) { - bit_offset = 7; - v = stbi__get8(s); - } - } - stbi__skip(s, pad); - } - } else { - for (j=0; j < (int) s->img_y; ++j) { - for (i=0; i < (int) s->img_x; i += 2) { - int v=stbi__get8(s),v2=0; - if (info.bpp == 4) { - v2 = v & 15; - v >>= 4; - } - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - if (i+1 == (int) s->img_x) break; - v = (info.bpp == 8) ? stbi__get8(s) : v2; - out[z++] = pal[v][0]; - out[z++] = pal[v][1]; - out[z++] = pal[v][2]; - if (target == 4) out[z++] = 255; - } - stbi__skip(s, pad); - } + } else { + for (j = 0; j < (int)s->img_y; ++j) { + for (i = 0; i < (int)s->img_x; i += 2) { + int v = stbi__get8(s), v2 = 0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + } + stbi__skip(s, pad); } - } else { - int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; - int z = 0; - int easy=0; - stbi__skip(s, info.offset - info.extra_read - info.hsz); - if (info.bpp == 24) width = 3 * s->img_x; - else if (info.bpp == 16) width = 2*s->img_x; - else /* bpp = 32 and pad = 0 */ width=0; - pad = (-width) & 3; - if (info.bpp == 24) { - easy = 1; - } else if (info.bpp == 32) { - if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) - easy = 2; + } + } else { + int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, + bcount = 0, acount = 0; + int z = 0; + int easy = 0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) + width = 3 * s->img_x; + else if (info.bpp == 16) + width = 2 * s->img_x; + else /* bpp = 32 and pad = 0 */ + width = 0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - if (!easy) { - if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } - // right shift amt to put high bit in position #7 - rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); - gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); - bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); - ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); - if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr) - 7; + rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg) - 7; + gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb) - 7; + bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma) - 7; + acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); } - for (j=0; j < (int) s->img_y; ++j) { - if (easy) { - for (i=0; i < (int) s->img_x; ++i) { - unsigned char a; - out[z+2] = stbi__get8(s); - out[z+1] = stbi__get8(s); - out[z+0] = stbi__get8(s); - z += 3; - a = (easy == 2 ? stbi__get8(s) : 255); - all_a |= a; - if (target == 4) out[z++] = a; - } - } else { - int bpp = info.bpp; - for (i=0; i < (int) s->img_x; ++i) { - stbi__uint32 v = (bpp == 16 ? (stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); - unsigned int a; - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); - out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); - a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); - all_a |= a; - if (target == 4) out[z++] = STBI__BYTECAST(a); - } - } - stbi__skip(s, pad); + } + for (j = 0; j < (int)s->img_y; ++j) { + if (easy) { + for (i = 0; i < (int)s->img_x; ++i) { + unsigned char a; + out[z + 2] = stbi__get8(s); + out[z + 1] = stbi__get8(s); + out[z + 0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) + out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i = 0; i < (int)s->img_x; ++i) { + stbi__uint32 v = + (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) + out[z++] = STBI__BYTECAST(a); + } } - } - - // if alpha channel is all 0s, replace with all 255s - if (target == 4 && all_a == 0) - for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) - out[i] = 255; - - if (flip_vertically) { - stbi_uc t; - for (j=0; j < (int) s->img_y>>1; ++j) { - stbi_uc *p1 = out + j *s->img_x*target; - stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; - for (i=0; i < (int) s->img_x*target; ++i) { - t = p1[i]; p1[i] = p2[i]; p2[i] = t; - } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j = 0; j < (int)s->img_y >> 1; ++j) { + stbi_uc *p1 = out + j * s->img_x * target; + stbi_uc *p2 = out + (s->img_y - 1 - j) * s->img_x * target; + for (i = 0; i < (int)s->img_x * target; ++i) { + t = p1[i]; + p1[i] = p2[i]; + p2[i] = t; } - } + } + } - if (req_comp && req_comp != target) { - out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; - return out; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + return out; } #endif @@ -5555,592 +6271,625 @@ static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req // by Jonathan Dummer #ifndef STBI_NO_TGA // returns STBI_rgb or whatever, 0 on error -static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) -{ - // only RGB or RGBA (incl. 16bit) or grey allowed - if (is_rgb16) *is_rgb16 = 0; - switch(bits_per_pixel) { - case 8: return STBI_grey; - case 16: if(is_grey) return STBI_grey_alpha; - // fallthrough - case 15: if(is_rgb16) *is_rgb16 = 1; - return STBI_rgb; - case 24: // fallthrough - case 32: return bits_per_pixel/8; - default: return 0; - } -} - -static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) -{ - int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; - int sz, tga_colormap_type; - stbi__get8(s); // discard Offset - tga_colormap_type = stbi__get8(s); // colormap type - if( tga_colormap_type > 1 ) { - stbi__rewind(s); - return 0; // only RGB or indexed allowed - } - tga_image_type = stbi__get8(s); // image type - if ( tga_colormap_type == 1 ) { // colormapped (paletted) image - if (tga_image_type != 1 && tga_image_type != 9) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { - stbi__rewind(s); - return 0; - } - stbi__skip(s,4); // skip image x and y origin - tga_colormap_bpp = sz; - } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE - if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { - stbi__rewind(s); - return 0; // only RGB or grey allowed, +/- RLE - } - stbi__skip(s,9); // skip colormap specification and image x/y origin - tga_colormap_bpp = 0; - } - tga_w = stbi__get16le(s); - if( tga_w < 1 ) { - stbi__rewind(s); - return 0; // test width +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int *is_rgb16) { + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) + *is_rgb16 = 0; + switch (bits_per_pixel) { + case 8: + return STBI_grey; + case 16: + if (is_grey) + return STBI_grey_alpha; + // fallthrough + case 15: + if (is_rgb16) + *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: + return bits_per_pixel / 8; + default: + return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) { + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, + tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if (tga_colormap_type > 1) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; } - tga_h = stbi__get16le(s); - if( tga_h < 1 ) { - stbi__rewind(s); - return 0; // test height + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__rewind(s); + return 0; } - tga_bits_per_pixel = stbi__get8(s); // bits per pixel - stbi__get8(s); // ignore alpha bits - if (tga_colormap_bpp != 0) { - if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { - // when using a colormap, tga_bits_per_pixel is the size of the indexes - // I don't think anything but 8 or 16bit indexes makes sense - stbi__rewind(s); - return 0; - } - tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); - } else { - tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + stbi__skip(s, 4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ((tga_image_type != 2) && (tga_image_type != 3) && + (tga_image_type != 10) && (tga_image_type != 11)) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE } - if(!tga_comp) { + stbi__skip(s, 9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if (tga_w < 1) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if (tga_h < 1) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if ((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense stbi__rewind(s); return 0; } - if (x) *x = tga_w; - if (y) *y = tga_h; - if (comp) *comp = tga_comp; - return 1; // seems to have passed everything -} - -static int stbi__tga_test(stbi__context *s) -{ - int res = 0; - int sz, tga_color_type; - stbi__get8(s); // discard Offset - tga_color_type = stbi__get8(s); // color type - if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed - sz = stbi__get8(s); // image type - if ( tga_color_type == 1 ) { // colormapped (paletted) image - if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 - stbi__skip(s,4); // skip index of first colormap entry and number of entries - sz = stbi__get8(s); // check bits per palette color entry - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - stbi__skip(s,4); // skip image x and y origin - } else { // "normal" image w/o colormap - if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE - stbi__skip(s,9); // skip colormap specification and image x/y origin - } - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width - if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height - sz = stbi__get8(s); // bits per pixel - if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index - if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; - - res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp( + tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), + NULL); + } + if (!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) + *x = tga_w; + if (y) + *y = tga_h; + if (comp) + *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) { + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if (tga_color_type > 1) + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image + if (sz != 1 && sz != 9) + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, + 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin + } + if (stbi__get16le(s) < 1) + goto errorEnd; // test width + if (stbi__get16le(s) < 1) + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) + goto errorEnd; // for colormapped images, bpp is size of an index + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead + // of 0 errorEnd: - stbi__rewind(s); - return res; + stbi__rewind(s); + return res; } // read 16bit value and convert to 24bit RGB -static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) -{ - stbi__uint16 px = (stbi__uint16)stbi__get16le(s); - stbi__uint16 fiveBitMask = 31; - // we have 3 channels with 5bits each - int r = (px >> 10) & fiveBitMask; - int g = (px >> 5) & fiveBitMask; - int b = px & fiveBitMask; - // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later - out[0] = (stbi_uc)((r * 255)/31); - out[1] = (stbi_uc)((g * 255)/31); - out[2] = (stbi_uc)((b * 255)/31); - - // some people claim that the most significant bit might be used for alpha - // (possibly if an alpha-bit is set in the "image descriptor byte") - // but that only made 16bit test images completely translucent.. - // so let's treat all 15 and 16bit TGAs as RGB with no alpha. -} - -static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - // read in the TGA header stuff - int tga_offset = stbi__get8(s); - int tga_indexed = stbi__get8(s); - int tga_image_type = stbi__get8(s); - int tga_is_RLE = 0; - int tga_palette_start = stbi__get16le(s); - int tga_palette_len = stbi__get16le(s); - int tga_palette_bits = stbi__get8(s); - int tga_x_origin = stbi__get16le(s); - int tga_y_origin = stbi__get16le(s); - int tga_width = stbi__get16le(s); - int tga_height = stbi__get16le(s); - int tga_bits_per_pixel = stbi__get8(s); - int tga_comp, tga_rgb16=0; - int tga_inverted = stbi__get8(s); - // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) - // image data - unsigned char *tga_data; - unsigned char *tga_palette = NULL; - int i, j; - unsigned char raw_data[4] = {0}; - int RLE_count = 0; - int RLE_repeating = 0; - int read_next_pixel = 1; - STBI_NOTUSED(ri); - STBI_NOTUSED(tga_x_origin); // @TODO - STBI_NOTUSED(tga_y_origin); // @TODO - - if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // do a tiny bit of precessing - if ( tga_image_type >= 8 ) - { - tga_image_type -= 8; - tga_is_RLE = 1; - } - tga_inverted = 1 - ((tga_inverted >> 5) & 1); - - // If I'm paletted, then I'll use the number of bits from the palette - if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); - else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); - - if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency - return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); - - // tga info - *x = tga_width; - *y = tga_height; - if (comp) *comp = tga_comp; - - if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) - return stbi__errpuc("too large", "Corrupt TGA"); - - tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); - if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); - - // skip to the data's starting position (offset usually = 0) - stbi__skip(s, tga_offset ); - - if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { - for (i=0; i < tga_height; ++i) { - int row = tga_inverted ? tga_height -i - 1 : i; - stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; - stbi__getn(s, tga_row, tga_width * tga_comp); - } - } else { - // do I need to load a palette? - if ( tga_indexed) - { - if (tga_palette_len == 0) { /* you have to have at least one entry! */ - STBI_FREE(tga_data); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } - - // any data to skip? (offset usually = 0) - stbi__skip(s, tga_palette_start ); - // load the palette - tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); - if (!tga_palette) { - STBI_FREE(tga_data); - return stbi__errpuc("outofmem", "Out of memory"); - } - if (tga_rgb16) { - stbi_uc *pal_entry = tga_palette; - STBI_ASSERT(tga_comp == STBI_rgb); - for (i=0; i < tga_palette_len; ++i) { - stbi__tga_read_rgb16(s, pal_entry); - pal_entry += tga_comp; - } - } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { - STBI_FREE(tga_data); - STBI_FREE(tga_palette); - return stbi__errpuc("bad palette", "Corrupt TGA"); - } +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc *out) { + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be + // swapped later + out[0] = (stbi_uc)((r * 255) / 31); + out[1] = (stbi_uc)((g * 255) / 31); + out[2] = (stbi_uc)((b * 255) / 31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16 = 0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused + // (useless?) + // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // do a tiny bit of precessing + if (tga_image_type >= 8) { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if (tga_indexed) + tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), + &tga_rgb16); + + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have + // ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) + *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = + (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) + return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset); + + if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { + for (i = 0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height - i - 1 : i; + stbi_uc *tga_row = tga_data + row * tga_width * tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if (tga_indexed) { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // load the data - for (i=0; i < tga_width * tga_height; ++i) - { - // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? - if ( tga_is_RLE ) - { - if ( RLE_count == 0 ) - { - // yep, get the next byte as a RLE command - int RLE_cmd = stbi__get8(s); - RLE_count = 1 + (RLE_cmd & 127); - RLE_repeating = RLE_cmd >> 7; - read_next_pixel = 1; - } else if ( !RLE_repeating ) - { - read_next_pixel = 1; - } - } else - { - read_next_pixel = 1; - } - // OK, if I need to read a pixel, do it now - if ( read_next_pixel ) - { - // load however much data we did have - if ( tga_indexed ) - { - // read in index, then perform the lookup - int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); - if ( pal_idx >= tga_palette_len ) { - // invalid index - pal_idx = 0; - } - pal_idx *= tga_comp; - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = tga_palette[pal_idx+j]; - } - } else if(tga_rgb16) { - STBI_ASSERT(tga_comp == STBI_rgb); - stbi__tga_read_rgb16(s, raw_data); - } else { - // read in the data raw - for (j = 0; j < tga_comp; ++j) { - raw_data[j] = stbi__get8(s); - } - } - // clear the reading flag for the next pixel - read_next_pixel = 0; - } // end of reading a pixel - - // copy data - for (j = 0; j < tga_comp; ++j) - tga_data[i*tga_comp+j] = raw_data[j]; - // in case we're in RLE mode, keep counting down - --RLE_count; + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start); + // load the palette + tga_palette = + (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); } - // do I need to invert the image? - if ( tga_inverted ) - { - for (j = 0; j*2 < tga_height; ++j) - { - int index1 = j * tga_width * tga_comp; - int index2 = (tga_height - 1 - j) * tga_width * tga_comp; - for (i = tga_width * tga_comp; i > 0; --i) - { - unsigned char temp = tga_data[index1]; - tga_data[index1] = tga_data[index2]; - tga_data[index2] = temp; - ++index1; - ++index2; - } - } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i = 0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); } - // clear my palette, if I had one - if ( tga_palette != NULL ) - { - STBI_FREE( tga_palette ); + } + // load the data + for (i = 0; i < tga_width * tga_height; ++i) { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if (tga_is_RLE) { + if (RLE_count == 0) { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if (!RLE_repeating) { + read_next_pixel = 1; + } + } else { + read_next_pixel = 1; } - } + // OK, if I need to read a pixel, do it now + if (read_next_pixel) { + // load however much data we did have + if (tga_indexed) { + // read in index, then perform the lookup + int pal_idx = + (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if (pal_idx >= tga_palette_len) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx + j]; + } + } else if (tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel - // swap RGB - if the source data was RGB16, it already is in the right order - if (tga_comp >= 3 && !tga_rgb16) - { - unsigned char* tga_pixel = tga_data; - for (i=0; i < tga_width * tga_height; ++i) - { - unsigned char temp = tga_pixel[0]; - tga_pixel[0] = tga_pixel[2]; - tga_pixel[2] = temp; - tga_pixel += tga_comp; + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i * tga_comp + j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if (tga_inverted) { + for (j = 0; j * 2 < tga_height; ++j) { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } } - } + } + // clear my palette, if I had one + if (tga_palette != NULL) { + STBI_FREE(tga_palette); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) { + unsigned char *tga_pixel = tga_data; + for (i = 0; i < tga_width * tga_height; ++i) { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } - // convert to target component count - if (req_comp && req_comp != tga_comp) - tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, + tga_height); - // the things I do to get rid of an error message, and yet keep - // Microsoft's C compilers happy... [8^( - tga_palette_start = tga_palette_len = tga_palette_bits = - tga_x_origin = tga_y_origin = 0; - STBI_NOTUSED(tga_palette_start); - // OK, done - return tga_data; + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = + tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; } #endif // ************************************************************************************************* -// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, +// tweaked by STB #ifndef STBI_NO_PSD -static int stbi__psd_test(stbi__context *s) -{ - int r = (stbi__get32be(s) == 0x38425053); - stbi__rewind(s); - return r; -} - -static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) -{ - int count, nleft, len; - - count = 0; - while ((nleft = pixelCount - count) > 0) { - len = stbi__get8(s); - if (len == 128) { - // No-op. - } else if (len < 128) { - // Copy next len+1 bytes literally. - len++; - if (len > nleft) return 0; // corrupt data - count += len; - while (len) { - *p = stbi__get8(s); - p += 4; - len--; - } - } else if (len > 128) { - stbi_uc val; - // Next -len+1 bytes in the dest are replicated from next source byte. - // (Interpret len as a negative 8-bit int.) - len = 257 - len; - if (len > nleft) return 0; // corrupt data - val = stbi__get8(s); - count += len; - while (len) { - *p = val; - p += 4; - len--; - } +static int stbi__psd_test(stbi__context *s) { + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) { + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) + return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; } - } - - return 1; -} + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) + return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri, int bpc) { + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w, h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", + "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", + "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for + // other modes.) + stbi__skip(s, stbi__get32be(s)); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s)); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s)); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", + "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *)stbi__malloc(4 * w * h); + + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w * h; + + // Initialize the data to zero. + // memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes + // literally. Else if n is between -127 and -1 inclusive, copy the next + // byte -n+1 times. Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row + // in the data, which we're going to just skip. + stbi__skip(s, h * channelCount * 2); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out + channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } -static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) -{ - int pixelCount; - int channelCount, compression; - int channel, i; - int bitdepth; - int w,h; - stbi_uc *out; - STBI_NOTUSED(ri); - - // Check identifier - if (stbi__get32be(s) != 0x38425053) // "8BPS" - return stbi__errpuc("not PSD", "Corrupt PSD image"); - - // Check file type version. - if (stbi__get16be(s) != 1) - return stbi__errpuc("wrong version", "Unsupported version of PSD image"); - - // Skip 6 reserved bytes. - stbi__skip(s, 6 ); - - // Read the number of channels (R, G, B, A, etc). - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) - return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); - - // Read the rows and columns of the image. - h = stbi__get32be(s); - w = stbi__get32be(s); - - if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - - // Make sure the depth is 8 bits. - bitdepth = stbi__get16be(s); - if (bitdepth != 8 && bitdepth != 16) - return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); - - // Make sure the color mode is RGB. - // Valid options are: - // 0: Bitmap - // 1: Grayscale - // 2: Indexed color - // 3: RGB color - // 4: CMYK color - // 7: Multichannel - // 8: Duotone - // 9: Lab color - if (stbi__get16be(s) != 3) - return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); - - // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) - stbi__skip(s,stbi__get32be(s) ); - - // Skip the image resources. (resolution, pen tool paths, etc) - stbi__skip(s, stbi__get32be(s) ); - - // Skip the reserved data. - stbi__skip(s, stbi__get32be(s) ); - - // Find out if the data is compressed. - // Known values: - // 0: no compression - // 1: RLE compressed - compression = stbi__get16be(s); - if (compression > 1) - return stbi__errpuc("bad compression", "PSD has an unknown compression format"); - - // Check size - if (!stbi__mad3sizes_valid(4, w, h, 0)) - return stbi__errpuc("too large", "Corrupt PSD"); - - // Create the destination image. - - if (!compression && bitdepth == 16 && bpc == 16) { - out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); - ri->bits_per_channel = 16; - } else - out = (stbi_uc *) stbi__malloc(4 * w*h); - - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - pixelCount = w*h; - - // Initialize the data to zero. - //memset( out, 0, pixelCount * 4 ); - - // Finally, the image data. - if (compression) { - // RLE as used by .PSD and .TIFF - // Loop until you get the number of unpacked bytes you are expecting: - // Read the next source byte into n. - // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. - // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. - // Else if n is 128, noop. - // Endloop - - // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, - // which we're going to just skip. - stbi__skip(s, h * channelCount * 2 ); - - // Read the RLE data by channel. - for (channel = 0; channel < 4; channel++) { - stbi_uc *p; - - p = out+channel; - if (channel >= channelCount) { - // Fill this channel with default data. + } else { + // We're at the raw image data. It's each channel in order (Red, Green, + // Blue, Alpha, ...) where each channel consists of an 8-bit (or 16-bit) + // value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out + channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *)out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16)stbi__get16be(s); + } else { + stbi_uc *p = out + channel; + if (bitdepth == 16) { // input bpc for (i = 0; i < pixelCount; i++, p += 4) - *p = (channel == 3 ? 255 : 0); - } else { - // Read the RLE data. - if (!stbi__psd_decode_rle(s, p, pixelCount)) { - STBI_FREE(out); - return stbi__errpuc("corrupt", "bad RLE data"); - } - } + *p = (stbi_uc)(stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } } - - } else { - // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) - // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. - - // Read the data by channel. - for (channel = 0; channel < 4; channel++) { - if (channel >= channelCount) { - // Fill this channel with default data. - if (bitdepth == 16 && bpc == 16) { - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - stbi__uint16 val = channel == 3 ? 65535 : 0; - for (i = 0; i < pixelCount; i++, q += 4) - *q = val; - } else { - stbi_uc *p = out+channel; - stbi_uc val = channel == 3 ? 255 : 0; - for (i = 0; i < pixelCount; i++, p += 4) - *p = val; - } - } else { - if (ri->bits_per_channel == 16) { // output bpc - stbi__uint16 *q = ((stbi__uint16 *) out) + channel; - for (i = 0; i < pixelCount; i++, q += 4) - *q = (stbi__uint16) stbi__get16be(s); - } else { - stbi_uc *p = out+channel; - if (bitdepth == 16) { // input bpc - for (i = 0; i < pixelCount; i++, p += 4) - *p = (stbi_uc) (stbi__get16be(s) >> 8); - } else { - for (i = 0; i < pixelCount; i++, p += 4) - *p = stbi__get8(s); - } - } - } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i = 0; i < w * h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *)out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); + pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); + pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); + } } - } - - // remove weird white matte from PSD - if (channelCount >= 4) { - if (ri->bits_per_channel == 16) { - for (i=0; i < w*h; ++i) { - stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; - if (pixel[3] != 0 && pixel[3] != 65535) { - float a = pixel[3] / 65535.0f; - float ra = 1.0f / a; - float inv_a = 65535.0f * (1 - ra); - pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); - pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); - pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); - } - } - } else { - for (i=0; i < w*h; ++i) { - unsigned char *pixel = out + 4*i; - if (pixel[3] != 0 && pixel[3] != 255) { - float a = pixel[3] / 255.0f; - float ra = 1.0f / a; - float inv_a = 255.0f * (1 - ra); - pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); - pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); - pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); - } - } + } else { + for (i = 0; i < w * h; ++i) { + unsigned char *pixel = out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); + pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); + pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); + } } - } - - // convert to desired output format - if (req_comp && req_comp != 4) { - if (ri->bits_per_channel == 16) - out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); - else - out = stbi__convert_format(out, 4, req_comp, w, h); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - - if (comp) *comp = 4; - *y = h; - *x = w; - - return out; + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, + w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + if (comp) + *comp = 4; + *y = h; + *x = w; + + return out; } #endif @@ -6152,215 +6901,222 @@ static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req // See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ #ifndef STBI_NO_PIC -static int stbi__pic_is4(stbi__context *s,const char *str) -{ - int i; - for (i=0; i<4; ++i) - if (stbi__get8(s) != (stbi_uc)str[i]) - return 0; +static int stbi__pic_is4(stbi__context *s, const char *str) { + int i; + for (i = 0; i < 4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; - return 1; + return 1; } -static int stbi__pic_test_core(stbi__context *s) -{ - int i; +static int stbi__pic_test_core(stbi__context *s) { + int i; - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) - return 0; + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) + return 0; - for(i=0;i<84;++i) - stbi__get8(s); + for (i = 0; i < 84; ++i) + stbi__get8(s); - if (!stbi__pic_is4(s,"PICT")) - return 0; + if (!stbi__pic_is4(s, "PICT")) + return 0; - return 1; + return 1; } -typedef struct -{ - stbi_uc size,type,channel; +typedef struct { + stbi_uc size, type, channel; } stbi__pic_packet; -static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) -{ - int mask=0x80, i; +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) { + int mask = 0x80, i; - for (i=0; i<4; ++i, mask>>=1) { - if (channel & mask) { - if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); - dest[i]=stbi__get8(s); - } - } + for (i = 0; i < 4; ++i, mask >>= 1) { + if (channel & mask) { + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "PIC file too short"); + dest[i] = stbi__get8(s); + } + } - return dest; + return dest; } -static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) -{ - int mask=0x80,i; +static void stbi__copyval(int channel, stbi_uc *dest, const stbi_uc *src) { + int mask = 0x80, i; - for (i=0;i<4; ++i, mask>>=1) - if (channel&mask) - dest[i]=src[i]; + for (i = 0; i < 4; ++i, mask >>= 1) + if (channel & mask) + dest[i] = src[i]; } -static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) -{ - int act_comp=0,num_packets=0,y,chained; - stbi__pic_packet packets[10]; +static stbi_uc *stbi__pic_load_core(stbi__context *s, int width, int height, + int *comp, stbi_uc *result) { + int act_comp = 0, num_packets = 0, y, chained; + stbi__pic_packet packets[10]; - // this will (should...) cater for even some bizarre stuff like having data - // for the same channel in multiple packets. - do { - stbi__pic_packet *packet; + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. + do { + stbi__pic_packet *packet; - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return stbi__errpuc("bad format","too many packets"); + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return stbi__errpuc("bad format", "too many packets"); - packet = &packets[num_packets++]; + packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); - act_comp |= packet->channel; + act_comp |= packet->channel; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (reading packets)"); - if (packet->size != 8) return stbi__errpuc("bad format","packet isn't 8bpp"); - } while (chained); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (reading packets)"); + if (packet->size != 8) + return stbi__errpuc("bad format", "packet isn't 8bpp"); + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? - for(y=0; ytype) { - default: - return stbi__errpuc("bad format","packet has bad compression type"); + switch (packet->type) { + default: + return stbi__errpuc("bad format", "packet has bad compression type"); - case 0: {//uncompressed - int x; + case 0: { // uncompressed + int x; - for(x=0;xchannel,dest)) - return 0; - break; - } + for (x = 0; x < width; ++x, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + break; + } - case 1://Pure RLE - { - int left=width, i; - - while (left>0) { - stbi_uc count,value[4]; - - count=stbi__get8(s); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pure read count)"); - - if (count > left) - count = (stbi_uc) left; - - if (!stbi__readval(s,packet->channel,value)) return 0; - - for(i=0; ichannel,dest,value); - left -= count; - } - } - break; - - case 2: {//Mixed RLE - int left=width; - while (left>0) { - int count = stbi__get8(s), i; - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (mixed read count)"); - - if (count >= 128) { // Repeated - stbi_uc value[4]; - - if (count==128) - count = stbi__get16be(s); - else - count -= 127; - if (count > left) - return stbi__errpuc("bad file","scanline overrun"); - - if (!stbi__readval(s,packet->channel,value)) - return 0; - - for(i=0;ichannel,dest,value); - } else { // Raw - ++count; - if (count>left) return stbi__errpuc("bad file","scanline overrun"); - - for(i=0;ichannel,dest)) - return 0; - } - left-=count; - } - break; - } - } + case 1: // Pure RLE + { + int left = width, i; + + while (left > 0) { + stbi_uc count, value[4]; + + count = stbi__get8(s); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pure read count)"); + + if (count > left) + count = (stbi_uc)left; + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + left -= count; + } + } break; + + case 2: { // Mixed RLE + int left = width; + while (left > 0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", + "file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count == 128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + } else { // Raw + ++count; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + for (i = 0; i < count; ++i, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + } + left -= count; + } + break; + } } - } + } + } - return result; + return result; } -static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri) -{ - stbi_uc *result; - int i, x,y, internal_comp; - STBI_NOTUSED(ri); +static void *stbi__pic_load(stbi__context *s, int *px, int *py, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *result; + int i, x, y, internal_comp; + STBI_NOTUSED(ri); - if (!comp) comp = &internal_comp; + if (!comp) + comp = &internal_comp; - for (i=0; i<92; ++i) - stbi__get8(s); + for (i = 0; i < 92; ++i) + stbi__get8(s); - x = stbi__get16be(s); - y = stbi__get16be(s); + x = stbi__get16be(s); + y = stbi__get16be(s); - if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - if (stbi__at_eof(s)) return stbi__errpuc("bad file","file too short (pic header)"); - if (!stbi__mad3sizes_valid(x, y, 4, 0)) return stbi__errpuc("too large", "PIC image too large to decode"); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) + return stbi__errpuc("too large", "PIC image too large to decode"); - stbi__get32be(s); //skip `ratio' - stbi__get16be(s); //skip `fields' - stbi__get16be(s); //skip `pad' + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' - // intermediate buffer is RGBA - result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0); - memset(result, 0xff, x*y*4); + // intermediate buffer is RGBA + result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); + memset(result, 0xff, x * y * 4); - if (!stbi__pic_load_core(s,x,y,comp, result)) { - STBI_FREE(result); - result=0; - } - *px = x; - *py = y; - if (req_comp == 0) req_comp = *comp; - result=stbi__convert_format(result,4,req_comp,x,y); + if (!stbi__pic_load_core(s, x, y, comp, result)) { + STBI_FREE(result); + result = 0; + } + *px = x; + *py = y; + if (req_comp == 0) + req_comp = *comp; + result = stbi__convert_format(result, 4, req_comp, x, y); - return result; + return result; } -static int stbi__pic_test(stbi__context *s) -{ - int r = stbi__pic_test_core(s); - stbi__rewind(s); - return r; +static int stbi__pic_test(stbi__context *s) { + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; } #endif @@ -6368,514 +7124,539 @@ static int stbi__pic_test(stbi__context *s) // GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb #ifndef STBI_NO_GIF -typedef struct -{ - stbi__int16 prefix; - stbi_uc first; - stbi_uc suffix; +typedef struct { + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; } stbi__gif_lzw; -typedef struct -{ - int w,h; - stbi_uc *out; // output buffer (always 4 components) - stbi_uc *background; // The current "background" as far as a gif is concerned - stbi_uc *history; - int flags, bgindex, ratio, transparent, eflags; - stbi_uc pal[256][4]; - stbi_uc lpal[256][4]; - stbi__gif_lzw codes[8192]; - stbi_uc *color_table; - int parse, step; - int lflags; - int start_x, start_y; - int max_x, max_y; - int cur_x, cur_y; - int line_size; - int delay; +typedef struct { + int w, h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; } stbi__gif; -static int stbi__gif_test_raw(stbi__context *s) -{ - int sz; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; - sz = stbi__get8(s); - if (sz != '9' && sz != '7') return 0; - if (stbi__get8(s) != 'a') return 0; - return 1; -} - -static int stbi__gif_test(stbi__context *s) -{ - int r = stbi__gif_test_raw(s); - stbi__rewind(s); - return r; -} - -static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) -{ - int i; - for (i=0; i < num_entries; ++i) { - pal[i][2] = stbi__get8(s); - pal[i][1] = stbi__get8(s); - pal[i][0] = stbi__get8(s); - pal[i][3] = transp == i ? 0 : 255; - } -} +static int stbi__gif_test_raw(stbi__context *s) { + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') + return 0; + if (stbi__get8(s) != 'a') + return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) { + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], + int num_entries, int transp) { + int i; + for (i = 0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, + int is_info) { + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || + stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') + return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') + return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + + if (comp != 0) + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the + // comments + + if (is_info) + return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) { + stbi__gif *g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind(s); + return 0; + } + if (x) + *x = g->w; + if (y) + *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) { + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) + return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) { + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) + return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc)init_code; + g->codes[init_code].suffix = (stbi_uc)init_code; + } + + // support no starting clear code + avail = clear + 2; + oldcode = -1; + + len = 0; + for (;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32)stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s, len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } -static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) -{ - stbi_uc version; - if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') - return stbi__err("not GIF", "Corrupt GIF"); + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } - version = stbi__get8(s); - if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); - if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + p->prefix = (stbi__int16)oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - stbi__g_failure_reason = ""; - g->w = stbi__get16le(s); - g->h = stbi__get16le(s); - g->flags = stbi__get8(s); - g->bgindex = stbi__get8(s); - g->ratio = stbi__get8(s); - g->transparent = -1; + stbi__out_gif_code(g, (stbi__uint16)code); - if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); - if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } - if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image +// doesn't support it two back is the image from two frames ago, used for a very +// specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, + int req_comp, stbi_uc *two_back) { + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour + // (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp, 0)) + return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *)stbi__malloc(4 * pcount); + g->background = (stbi_uc *)stbi__malloc(4 * pcount); + g->history = (stbi_uc *)stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); - if (is_info) return 1; + // image is treated as "transparent" at the start - ie, nothing overwrites + // the current background; background colour is only used for pixels that + // are not rendered first frame, after that "background" color refers to the + // color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, + 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, + pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the + // old background + } - if (g->flags & 0x80) - stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. + } - return 1; -} + // background is what out is after the undoing of the previou frame; + memcpy(g->background, g->out, 4 * g->w * g->h); + } + + // clear my history; + memset(g->history, 0x00, + g->w * g->h); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc *o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; -static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) -{ - stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); - if (!stbi__gif_header(s, g, comp, 1)) { - STBI_FREE(g); - stbi__rewind( s ); - return 0; - } - if (x) *x = g->w; - if (y) *y = g->h; - STBI_FREE(g); - return 1; -} + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; -static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) -{ - stbi_uc *p, *c; - int idx; - - // recurse to decode the prefixes, since the linked-list is backwards, - // and working backwards through an interleaved image would be nasty - if (g->codes[code].prefix >= 0) - stbi__out_gif_code(g, g->codes[code].prefix); - - if (g->cur_y >= g->max_y) return; - - idx = g->cur_x + g->cur_y; - p = &g->out[idx]; - g->history[idx / 4] = 1; - - c = &g->color_table[g->codes[code].suffix * 4]; - if (c[3] > 128) { // don't render transparent pixels; - p[0] = c[2]; - p[1] = c[1]; - p[2] = c[0]; - p[3] = c[3]; - } - g->cur_x += 4; - - if (g->cur_x >= g->max_x) { - g->cur_x = g->start_x; - g->cur_y += g->step; + g->lflags = stbi__get8(s); - while (g->cur_y >= g->max_y && g->parse > 0) { - g->step = (1 << g->parse) * g->line_size; - g->cur_y = g->start_y + (g->step >> 1); - --g->parse; + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; } - } -} -static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) -{ - stbi_uc lzw_cs; - stbi__int32 len, init_code; - stbi__uint32 first; - stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; - stbi__gif_lzw *p; - - lzw_cs = stbi__get8(s); - if (lzw_cs > 12) return NULL; - clear = 1 << lzw_cs; - first = 1; - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - bits = 0; - valid_bits = 0; - for (init_code = 0; init_code < clear; init_code++) { - g->codes[init_code].prefix = -1; - g->codes[init_code].first = (stbi_uc) init_code; - g->codes[init_code].suffix = (stbi_uc) init_code; - } - - // support no starting clear code - avail = clear+2; - oldcode = -1; - - len = 0; - for(;;) { - if (valid_bits < codesize) { - if (len == 0) { - len = stbi__get8(s); // start new block - if (len == 0) - return g->out; - } - --len; - bits |= (stbi__int32) stbi__get8(s) << valid_bits; - valid_bits += 8; - } else { - stbi__int32 code = bits & codemask; - bits >>= codesize; - valid_bits -= codesize; - // @OPTIMIZE: is there some way we can accelerate the non-clear path? - if (code == clear) { // clear code - codesize = lzw_cs + 1; - codemask = (1 << codesize) - 1; - avail = clear + 2; - oldcode = -1; - first = 0; - } else if (code == clear + 1) { // end of stream code - stbi__skip(s, len); - while ((len = stbi__get8(s)) > 0) - stbi__skip(s,len); - return g->out; - } else if (code <= avail) { - if (first) { - return stbi__errpuc("no clear code", "Corrupt GIF"); - } + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), + g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *)g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *)g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); - if (oldcode >= 0) { - p = &g->codes[avail++]; - if (avail > 8192) { - return stbi__errpuc("too many codes", "Corrupt GIF"); - } + o = stbi__process_gif_raster(s, g); + if (!o) + return NULL; - p->prefix = (stbi__int16) oldcode; - p->first = g->codes[oldcode].first; - p->suffix = (code == avail) ? p->first : g->codes[code].first; - } else if (code == avail) - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = + 255; // just in case it was made transparent, undo that; It will + // be reset next frame if need be; + memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); + } + } + } - stbi__out_gif_code(g, (stbi__uint16) code); + return o; + } - if ((avail & codemask) == 0 && avail <= 0x0FFF) { - codesize++; - codemask = (1 << codesize) - 1; + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = + 10 * stbi__get16le( + s); // delay - 1/100th of a second, saving as 1/1000ths. + + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; } - - oldcode = code; - } else { - return stbi__errpuc("illegal code in raster", "Corrupt GIF"); - } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } } - } -} - -// this function is designed to support animated gifs, although stb_image doesn't support it -// two back is the image from two frames ago, used for a very specific disposal format -static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) -{ - int dispose; - int first_frame; - int pi; - int pcount; - STBI_NOTUSED(req_comp); - - // on first frame, any non-written pixels get the background colour (non-transparent) - first_frame = 0; - if (g->out == 0) { - if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header - if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) - return stbi__errpuc("too large", "GIF image is too large"); - pcount = g->w * g->h; - g->out = (stbi_uc *) stbi__malloc(4 * pcount); - g->background = (stbi_uc *) stbi__malloc(4 * pcount); - g->history = (stbi_uc *) stbi__malloc(pcount); - if (!g->out || !g->background || !g->history) - return stbi__errpuc("outofmem", "Out of memory"); - - // image is treated as "transparent" at the start - ie, nothing overwrites the current background; - // background colour is only used for pixels that are not rendered first frame, after that "background" - // color refers to the color that was there the previous frame. - memset(g->out, 0x00, 4 * pcount); - memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) - memset(g->history, 0x00, pcount); // pixels that were affected previous frame - first_frame = 1; - } else { - // second frame - how do we dispose of the previous one? - dispose = (g->eflags & 0x1C) >> 2; - pcount = g->w * g->h; - - if ((dispose == 3) && (two_back == 0)) { - dispose = 2; // if I don't have an image to revert back to, default to the old background + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); } + break; + } - if (dispose == 3) { // use previous graphic - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 ); - } - } - } else if (dispose == 2) { - // restore what was changed last frame to background before that frame; - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi]) { - memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 ); - } - } - } else { - // This is a non-disposal case eithe way, so just - // leave the pixels as is, and they will become the new background - // 1: do not dispose - // 0: not specified. - } + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers - // background is what out is after the undoing of the previou frame; - memcpy( g->background, g->out, 4 * g->w * g->h ); - } - - // clear my history; - memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame - - for (;;) { - int tag = stbi__get8(s); - switch (tag) { - case 0x2C: /* Image Descriptor */ - { - stbi__int32 x, y, w, h; - stbi_uc *o; - - x = stbi__get16le(s); - y = stbi__get16le(s); - w = stbi__get16le(s); - h = stbi__get16le(s); - if (((x + w) > (g->w)) || ((y + h) > (g->h))) - return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); - - g->line_size = g->w * 4; - g->start_x = x * 4; - g->start_y = y * g->line_size; - g->max_x = g->start_x + w * 4; - g->max_y = g->start_y + h * g->line_size; - g->cur_x = g->start_x; - g->cur_y = g->start_y; - - // if the width of the specified rectangle is 0, that means - // we may not see *any* pixels or the image is malformed; - // to make sure this is caught, move the current y down to - // max_y (which is what out_gif_code checks). - if (w == 0) - g->cur_y = g->max_y; - - g->lflags = stbi__get8(s); - - if (g->lflags & 0x40) { - g->step = 8 * g->line_size; // first interlaced spacing - g->parse = 3; - } else { - g->step = g->line_size; - g->parse = 0; - } + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, + int *z, int *comp, int req_comp) { + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc *u = 0; + stbi_uc *out = 0; + stbi_uc *two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } - if (g->lflags & 0x80) { - stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); - g->color_table = (stbi_uc *) g->lpal; - } else if (g->flags & 0x80) { - g->color_table = (stbi_uc *) g->pal; - } else - return stbi__errpuc("missing color table", "Corrupt GIF"); - - o = stbi__process_gif_raster(s, g); - if (!o) return NULL; - - // if this was the first frame, - pcount = g->w * g->h; - if (first_frame && (g->bgindex > 0)) { - // if first frame, any pixel not drawn to gets the background color - for (pi = 0; pi < pcount; ++pi) { - if (g->history[pi] == 0) { - g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; - memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 ); - } - } - } + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void *tmp = + (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + if (NULL == tmp) { + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + return stbi__errpuc("outofmem", "Out of memory"); + } else { + out = (stbi_uc *)tmp; + out_size = layers * stride; + } + + if (delays) { + *delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, + sizeof(int) * layers); + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc *)stbi__malloc(layers * stride); + out_size = layers * stride; + if (delays) { + *delays = (int *)stbi__malloc(layers * sizeof(int)); + delays_size = layers * sizeof(int); + } + } + memcpy(out + ((layers - 1) * stride), u, stride); + if (layers >= 2) { + two_back = out - 2 * stride; + } - return o; - } - - case 0x21: // Comment Extension. - { - int len; - int ext = stbi__get8(s); - if (ext == 0xF9) { // Graphic Control Extension. - len = stbi__get8(s); - if (len == 4) { - g->eflags = stbi__get8(s); - g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. - - // unset old transparent - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 255; - } - if (g->eflags & 0x01) { - g->transparent = stbi__get8(s); - if (g->transparent >= 0) { - g->pal[g->transparent][3] = 0; - } - } else { - // don't need transparent - stbi__skip(s, 1); - g->transparent = -1; - } - } else { - stbi__skip(s, len); - break; - } - } - while ((len = stbi__get8(s)) != 0) { - stbi__skip(s, len); - } - break; - } + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); - case 0x3B: // gif stream termination code - return (stbi_uc *) s; // using '1' causes warning on some compilers + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); - default: - return stbi__errpuc("unknown code", "Corrupt GIF"); - } - } -} + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); -static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp) -{ - if (stbi__gif_test(s)) { - int layers = 0; - stbi_uc *u = 0; - stbi_uc *out = 0; - stbi_uc *two_back = 0; - stbi__gif g; - int stride; - int out_size = 0; - int delays_size = 0; - memset(&g, 0, sizeof(g)); - if (delays) { - *delays = 0; - } + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} - do { - u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - - if (u) { - *x = g.w; - *y = g.h; - ++layers; - stride = g.w * g.h * 4; - - if (out) { - void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride ); - if (NULL == tmp) { - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); - return stbi__errpuc("outofmem", "Out of memory"); - } - else { - out = (stbi_uc*) tmp; - out_size = layers * stride; - } - - if (delays) { - *delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers ); - delays_size = layers * sizeof(int); - } - } else { - out = (stbi_uc*)stbi__malloc( layers * stride ); - out_size = layers * stride; - if (delays) { - *delays = (int*) stbi__malloc( layers * sizeof(int) ); - delays_size = layers * sizeof(int); - } - } - memcpy( out + ((layers - 1) * stride), u, stride ); - if (layers >= 2) { - two_back = out - 2 * stride; - } +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); - if (delays) { - (*delays)[layers - 1U] = g.delay; - } - } - } while (u != 0); + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; - // free temp buffer; - STBI_FREE(g.out); - STBI_FREE(g.history); - STBI_FREE(g.background); + // moved conversion to after successful load so that the same + // can be done for multiple frames. + if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } - // do the final conversion after loading everything; - if (req_comp && req_comp != 4) - out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); - *z = layers; - return out; - } else { - return stbi__errpuc("not GIF", "Image was not as a gif type."); - } + return u; } -static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *u = 0; - stbi__gif g; - memset(&g, 0, sizeof(g)); - STBI_NOTUSED(ri); - - u = stbi__gif_load_next(s, &g, comp, req_comp, 0); - if (u == (stbi_uc *) s) u = 0; // end of animated gif marker - if (u) { - *x = g.w; - *y = g.h; - - // moved conversion to after successful load so that the same - // can be done for multiple frames. - if (req_comp && req_comp != 4) - u = stbi__convert_format(u, 4, req_comp, g.w, g.h); - } else if (g.out) { - // if there was an error and we allocated an image buffer, free it! - STBI_FREE(g.out); - } - - // free buffers needed for multiple frame loading; - STBI_FREE(g.history); - STBI_FREE(g.background); - - return u; -} - -static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) -{ - return stbi__gif_info_raw(s,x,y,comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) { + return stbi__gif_info_raw(s, x, y, comp); } #endif @@ -6883,396 +7664,434 @@ static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) // Radiance RGBE HDR loader // originally by Nicolas Schulz #ifndef STBI_NO_HDR -static int stbi__hdr_test_core(stbi__context *s, const char *signature) -{ - int i; - for (i=0; signature[i]; ++i) - if (stbi__get8(s) != signature[i]) - return 0; - stbi__rewind(s); - return 1; -} - -static int stbi__hdr_test(stbi__context* s) -{ - int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); - stbi__rewind(s); - if(!r) { - r = stbi__hdr_test_core(s, "#?RGBE\n"); - stbi__rewind(s); - } - return r; -} - -#define STBI__HDR_BUFLEN 1024 -static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) -{ - int len=0; - char c = '\0'; - - c = (char) stbi__get8(z); - - while (!stbi__at_eof(z) && c != '\n') { - buffer[len++] = c; - if (len == STBI__HDR_BUFLEN-1) { - // flush to end of line - while (!stbi__at_eof(z) && stbi__get8(z) != '\n') - ; - break; +static int stbi__hdr_test_core(stbi__context *s, const char *signature) { + int i; + for (i = 0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context *s) { + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if (!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) { + int len = 0; + char c = '\0'; + + c = (char)stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN - 1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char)stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) { + if (input[3] != 0) { + float f1; + // Exponent + f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) + output[1] = 1; + if (req_comp == 4) + output[3] = 1; + } else { + switch (req_comp) { + case 4: + output[3] = 1; /* fallthrough */ + case 3: + output[0] = output[1] = output[2] = 0; + break; + case 2: + output[1] = 1; /* fallthrough */ + case 1: + output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1, c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s, buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && + strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) + return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int)strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) + *comp = 3; + if (req_comp == 0) + req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = + (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if (width < 8 || width >= 32768) { + // Read flat data + for (j = 0; j < height; ++j) { + for (i = 0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, + req_comp); } - c = (char) stbi__get8(z); - } - - buffer[len] = 0; - return buffer; -} + } + } else { + // Read RLE-encoded data + scanline = NULL; -static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) -{ - if ( input[3] != 0 ) { - float f1; - // Exponent - f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); - if (req_comp <= 2) - output[0] = (input[0] + input[1] + input[2]) * f1 / 3; - else { - output[0] = input[0] * f1; - output[1] = input[1] * f1; - output[2] = input[2] * f1; + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a + // decoded pixel (note this can't be a valid pixel--one of RGB must be + // >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc)c1; + rgbe[1] = (stbi_uc)c2; + rgbe[2] = (stbi_uc)len; + rgbe[3] = (stbi_uc)stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense } - if (req_comp == 2) output[1] = 1; - if (req_comp == 4) output[3] = 1; - } else { - switch (req_comp) { - case 4: output[3] = 1; /* fallthrough */ - case 3: output[0] = output[1] = output[2] = 0; - break; - case 2: output[1] = 1; /* fallthrough */ - case 1: output[0] = 0; - break; + len <<= 8; + len |= stbi__get8(s); + if (len != width) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - } -} - -static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int width, height; - stbi_uc *scanline; - float *hdr_data; - int len; - unsigned char count, value; - int i, j, k, c1,c2, z; - const char *headerToken; - STBI_NOTUSED(ri); - - // Check identifier - headerToken = stbi__hdr_gettoken(s,buffer); - if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) - return stbi__errpf("not HDR", "Corrupt HDR image"); - - // Parse header - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); - - // Parse width and height - // can't use sscanf() if we're not using stdio! - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - height = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) return stbi__errpf("unsupported data layout", "Unsupported HDR format"); - token += 3; - width = (int) strtol(token, NULL, 10); - - if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)"); - - *x = width; - *y = height; - - if (comp) *comp = 3; - if (req_comp == 0) req_comp = 3; - - if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) - return stbi__errpf("too large", "HDR image is too large"); - - // Read data - hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); - if (!hdr_data) - return stbi__errpf("outofmem", "Out of memory"); - - // Load image data - // image data is stored as some number of sca - if ( width < 8 || width >= 32768) { - // Read flat data - for (j=0; j < height; ++j) { - for (i=0; i < width; ++i) { - stbi_uc rgbe[4]; - main_decode_loop: - stbi__getn(s, rgbe, 4); - stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); - } + if (scanline == NULL) { + scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } } - } else { - // Read RLE-encoded data - scanline = NULL; - - for (j = 0; j < height; ++j) { - c1 = stbi__get8(s); - c2 = stbi__get8(s); - len = stbi__get8(s); - if (c1 != 2 || c2 != 2 || (len & 0x80)) { - // not run-length encoded, so we have to actually use THIS data as a decoded - // pixel (note this can't be a valid pixel--one of RGB must be >= 128) - stbi_uc rgbe[4]; - rgbe[0] = (stbi_uc) c1; - rgbe[1] = (stbi_uc) c2; - rgbe[2] = (stbi_uc) len; - rgbe[3] = (stbi_uc) stbi__get8(s); - stbi__hdr_convert(hdr_data, rgbe, req_comp); - i = 1; - j = 0; - STBI_FREE(scanline); - goto main_decode_loop; // yes, this makes no sense - } - len <<= 8; - len |= stbi__get8(s); - if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); } - if (scanline == NULL) { - scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0); - if (!scanline) { - STBI_FREE(hdr_data); - return stbi__errpf("outofmem", "Out of memory"); + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - - for (k = 0; k < 4; ++k) { - int nleft; - i = 0; - while ((nleft = width - i) > 0) { - count = stbi__get8(s); - if (count > 128) { - // Run - value = stbi__get8(s); - count -= 128; - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = value; - } else { - // Dump - if (count > nleft) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); } - for (z = 0; z < count; ++z) - scanline[i++ * 4 + k] = stbi__get8(s); - } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if (count > nleft) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); } - } - for (i=0; i < width; ++i) - stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp); + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } } - if (scanline) - STBI_FREE(scanline); - } - - return hdr_data; -} - -static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) -{ - char buffer[STBI__HDR_BUFLEN]; - char *token; - int valid = 0; - int dummy; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (stbi__hdr_test(s) == 0) { - stbi__rewind( s ); - return 0; - } - - for(;;) { - token = stbi__hdr_gettoken(s,buffer); - if (token[0] == 0) break; - if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; - } - - if (!valid) { - stbi__rewind( s ); - return 0; - } - token = stbi__hdr_gettoken(s,buffer); - if (strncmp(token, "-Y ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *y = (int) strtol(token, &token, 10); - while (*token == ' ') ++token; - if (strncmp(token, "+X ", 3)) { - stbi__rewind( s ); - return 0; - } - token += 3; - *x = (int) strtol(token, NULL, 10); - *comp = 3; - return 1; + for (i = 0; i < width; ++i) + stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, + scanline + i * 4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp) { + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int dummy; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind(s); + return 0; + } + + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) { + stbi__rewind(s); + return 0; + } + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *y = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *x = (int)strtol(token, NULL, 10); + *comp = 3; + return 1; } #endif // STBI_NO_HDR #ifndef STBI_NO_BMP -static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) -{ - void *p; - stbi__bmp_data info; - - info.all_a = 255; - p = stbi__bmp_parse_header(s, &info); - stbi__rewind( s ); - if (p == NULL) - return 0; - if (x) *x = s->img_x; - if (y) *y = s->img_y; - if (comp) { - if (info.bpp == 24 && info.ma == 0xff000000) - *comp = 3; - else - *comp = info.ma ? 4 : 3; - } - return 1; +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) { + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + stbi__rewind(s); + if (p == NULL) + return 0; + if (x) + *x = s->img_x; + if (y) + *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; } #endif #ifndef STBI_NO_PSD -static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) -{ - int channelCount, dummy, depth; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - *y = stbi__get32be(s); - *x = stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 8 && depth != 16) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 3) { - stbi__rewind( s ); - return 0; - } - *comp = 4; - return 1; -} - -static int stbi__psd_is16(stbi__context *s) -{ - int channelCount, depth; - if (stbi__get32be(s) != 0x38425053) { - stbi__rewind( s ); - return 0; - } - if (stbi__get16be(s) != 1) { - stbi__rewind( s ); - return 0; - } - stbi__skip(s, 6); - channelCount = stbi__get16be(s); - if (channelCount < 0 || channelCount > 16) { - stbi__rewind( s ); - return 0; - } - (void) stbi__get32be(s); - (void) stbi__get32be(s); - depth = stbi__get16be(s); - if (depth != 16) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) { + int channelCount, dummy, depth; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind(s); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) { + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + (void)stbi__get32be(s); + (void)stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind(s); + return 0; + } + return 1; } #endif #ifndef STBI_NO_PIC -static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) -{ - int act_comp=0,num_packets=0,chained,dummy; - stbi__pic_packet packets[10]; - - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; - - if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { - stbi__rewind(s); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) { + int act_comp = 0, num_packets = 0, chained, dummy; + stbi__pic_packet packets[10]; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) return 0; - } - stbi__skip(s, 88); + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; - *x = stbi__get16be(s); - *y = stbi__get16be(s); - if (stbi__at_eof(s)) { - stbi__rewind( s); + if (stbi__at_eof(s)) { + stbi__rewind(s); return 0; - } - if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { - stbi__rewind( s ); + } + if (packet->size != 8) { + stbi__rewind(s); return 0; - } - - stbi__skip(s, 8); - - do { - stbi__pic_packet *packet; - - if (num_packets==sizeof(packets)/sizeof(packets[0])) - return 0; - - packet = &packets[num_packets++]; - chained = stbi__get8(s); - packet->size = stbi__get8(s); - packet->type = stbi__get8(s); - packet->channel = stbi__get8(s); - act_comp |= packet->channel; - - if (stbi__at_eof(s)) { - stbi__rewind( s ); - return 0; - } - if (packet->size != 8) { - stbi__rewind( s ); - return 0; - } - } while (chained); + } + } while (chained); - *comp = (act_comp & 0x10 ? 4 : 3); + *comp = (act_comp & 0x10 ? 4 : 3); - return 1; + return 1; } #endif @@ -7290,257 +8109,266 @@ static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) #ifndef STBI_NO_PNM -static int stbi__pnm_test(stbi__context *s) -{ - char p, t; - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind( s ); - return 0; - } - return 1; +static int stbi__pnm_test(stbi__context *s) { + char p, t; + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + return 1; } -static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) -{ - stbi_uc *out; - STBI_NOTUSED(ri); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, + int req_comp, stbi__result_info *ri) { + stbi_uc *out; + STBI_NOTUSED(ri); - if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) - return 0; + if (!stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n)) + return 0; - if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); - if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); - *x = s->img_x; - *y = s->img_y; - if (comp) *comp = s->img_n; + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; - if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) - return stbi__errpuc("too large", "PNM too large"); + if (!stbi__mad3sizes_valid(s->img_n, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "PNM too large"); - out = (stbi_uc *) stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); - if (!out) return stbi__errpuc("outofmem", "Out of memory"); - stbi__getn(s, out, s->img_n * s->img_x * s->img_y); + out = (stbi_uc *)stbi__malloc_mad3(s->img_n, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + stbi__getn(s, out, s->img_n * s->img_x * s->img_y); - if (req_comp && req_comp != s->img_n) { - out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); - if (out == NULL) return out; // stbi__convert_format frees input on failure - } - return out; + if (req_comp && req_comp != s->img_n) { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + return out; } -static int stbi__pnm_isspace(char c) -{ - return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +static int stbi__pnm_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; } -static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) -{ - for (;;) { - while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) - *c = (char) stbi__get8(s); +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) { + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char)stbi__get8(s); - if (stbi__at_eof(s) || *c != '#') - break; + if (stbi__at_eof(s) || *c != '#') + break; - while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') + *c = (char)stbi__get8(s); + } } -static int stbi__pnm_isdigit(char c) -{ - return c >= '0' && c <= '9'; +static int stbi__pnm_isdigit(char c) { + return c >= '0' && c <= '9'; } -static int stbi__pnm_getinteger(stbi__context *s, char *c) -{ - int value = 0; +static int stbi__pnm_getinteger(stbi__context *s, char *c) { + int value = 0; - while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { - value = value*10 + (*c - '0'); - *c = (char) stbi__get8(s); - } + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value * 10 + (*c - '0'); + *c = (char)stbi__get8(s); + } - return value; + return value; } -static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) -{ - int maxv, dummy; - char c, p, t; +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) { + int maxv, dummy; + char c, p, t; - if (!x) x = &dummy; - if (!y) y = &dummy; - if (!comp) comp = &dummy; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; - stbi__rewind(s); + stbi__rewind(s); - // Get identifier - p = (char) stbi__get8(s); - t = (char) stbi__get8(s); - if (p != 'P' || (t != '5' && t != '6')) { - stbi__rewind(s); - return 0; - } + // Get identifier + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } - *comp = (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + *comp = + (t == '6') ? 3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm - c = (char) stbi__get8(s); - stbi__pnm_skip_whitespace(s, &c); + c = (char)stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); - *x = stbi__pnm_getinteger(s, &c); // read width - stbi__pnm_skip_whitespace(s, &c); + *x = stbi__pnm_getinteger(s, &c); // read width + stbi__pnm_skip_whitespace(s, &c); - *y = stbi__pnm_getinteger(s, &c); // read height - stbi__pnm_skip_whitespace(s, &c); + *y = stbi__pnm_getinteger(s, &c); // read height + stbi__pnm_skip_whitespace(s, &c); - maxv = stbi__pnm_getinteger(s, &c); // read max value + maxv = stbi__pnm_getinteger(s, &c); // read max value - if (maxv > 255) - return stbi__err("max value > 255", "PPM image not 8-bit"); - else - return 1; + if (maxv > 255) + return stbi__err("max value > 255", "PPM image not 8-bit"); + else + return 1; } #endif -static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) -{ - #ifndef STBI_NO_JPEG - if (stbi__jpeg_info(s, x, y, comp)) return 1; - #endif +static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp) { +#ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNG - if (stbi__png_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_GIF - if (stbi__gif_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_BMP - if (stbi__bmp_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PIC - if (stbi__pic_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_PNM - if (stbi__pnm_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) + return 1; +#endif - #ifndef STBI_NO_HDR - if (stbi__hdr_info(s, x, y, comp)) return 1; - #endif +#ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) + return 1; +#endif - // test tga last because it's a crappy test! - #ifndef STBI_NO_TGA - if (stbi__tga_info(s, x, y, comp)) - return 1; - #endif - return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +// test tga last because it's a crappy test! +#ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; +#endif + return stbi__err("unknown image type", + "Image not of any known type, or corrupt"); } -static int stbi__is_16_main(stbi__context *s) -{ - #ifndef STBI_NO_PNG - if (stbi__png_is16(s)) return 1; - #endif +static int stbi__is_16_main(stbi__context *s) { +#ifndef STBI_NO_PNG + if (stbi__png_is16(s)) + return 1; +#endif - #ifndef STBI_NO_PSD - if (stbi__psd_is16(s)) return 1; - #endif +#ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) + return 1; +#endif - return 0; + return 0; } #ifndef STBI_NO_STDIO -STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_info_from_file(f, x, y, comp); - fclose(f); - return result; -} - -STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__info_main(&s,x,y,comp); - fseek(f,pos,SEEK_SET); - return r; -} - -STBIDEF int stbi_is_16_bit(char const *filename) -{ - FILE *f = stbi__fopen(filename, "rb"); - int result; - if (!f) return stbi__err("can't fopen", "Unable to open file"); - result = stbi_is_16_bit_from_file(f); - fclose(f); - return result; -} - -STBIDEF int stbi_is_16_bit_from_file(FILE *f) -{ - int r; - stbi__context s; - long pos = ftell(f); - stbi__start_file(&s, f); - r = stbi__is_16_main(&s); - fseek(f,pos,SEEK_SET); - return r; +STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s, x, y, comp); + fseek(f, pos, SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const *filename) { + FILE *f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE *f) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f, pos, SEEK_SET); + return r; } #endif // !STBI_NO_STDIO -STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, + int *y, int *comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__info_main(&s,x,y,comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, + int *x, int *y, int *comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__info_main(&s, x, y, comp); } -STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) -{ - stbi__context s; - stbi__start_mem(&s,buffer,len); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__is_16_main(&s); } -STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) -{ - stbi__context s; - stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); - return stbi__is_16_main(&s); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, + void *user) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__is_16_main(&s); } #endif // STB_IMAGE_IMPLEMENTATION /* revision history: - 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs - 2.19 (2018-02-11) fix warning - 2.18 (2018-01-30) fix warnings - 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and + platform ifdefs 2.19 (2018-02-11) fix warning 2.18 (2018-01-30) fix + warnings 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug 1-bit BMP *_is_16_bit api avoid warnings @@ -7555,13 +8383,11 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user warning fixes; disable run-time SSE detection on gcc; uniform handling of optional "return" values; thread-safe initialization of zlib tables - 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs - 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now - 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes - 2.11 (2016-04-02) allocate large structures on the stack - remove white matting for transparent PSD - fix reported channel count for PNG & BMP - re-enable SSE2 in non-gcc 64-bit + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet + JPGs 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now 2.12 + (2016-04-02) fix typo in 2.11 PSD fix that caused crashes 2.11 (2016-04-02) + allocate large structures on the stack remove white matting for transparent + PSD fix reported channel count for PNG & BMP re-enable SSE2 in non-gcc 64-bit support RGB-formatted JPEG read 16-bit PNGs (only as 8-bit) 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED @@ -7569,11 +8395,9 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 16-bit-per-pixel TGA (not bit-per-component) info() for TGA could break due to .hdr handling info() for BMP to shares code instead of sloppy parse - can use STBI_REALLOC_SIZED if allocator doesn't support realloc - code cleanup - 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA - 2.07 (2015-09-13) fix compiler warnings - partial animated GIF support + can use STBI_REALLOC_SIZED if allocator doesn't support + realloc code cleanup 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD + as RGBA 2.07 (2015-09-13) fix compiler warnings partial animated GIF support limited 16-bpc PSD support #ifdef unused functions bug with < 92 byte PIC,PNM,HDR,TGA @@ -7584,23 +8408,18 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user stbi_set_flip_vertically_on_load (nguillemot) fix NEON support; fix mingw support 2.02 (2015-01-19) fix incorrect assert, fix warning - 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 - 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG - 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) - progressive JPEG (stb) - PGM/PPM support (Ken Miller) - STBI_MALLOC,STBI_REALLOC,STBI_FREE + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit + without -msse2 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG 2.00 + (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) progressive + JPEG (stb) PGM/PPM support (Ken Miller) STBI_MALLOC,STBI_REALLOC,STBI_FREE GIF bugfix -- seemingly never worked STBI_NO_*, STBI_ONLY_* 1.48 (2014-12-14) fix incorrectly-named assert() - 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) - optimize PNG (ryg) - fix bug in interlaced PNG with user-specified channel count (stb) - 1.46 (2014-08-26) - fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG - 1.45 (2014-08-16) - fix MSVC-ARM internal compiler error by wrapping malloc - 1.44 (2014-08-07) + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar + Cornut & stb) optimize PNG (ryg) fix bug in interlaced PNG with + user-specified channel count (stb) 1.46 (2014-08-26) fix broken tRNS chunk + (colorkey-style transparency) in non-paletted PNG 1.45 (2014-08-16) fix + MSVC-ARM internal compiler error by wrapping malloc 1.44 (2014-08-07) various warning fixes from Ronny Chevalier 1.43 (2014-07-15) fix MSVC-only compiler problem in code changed in 1.42 @@ -7609,73 +8428,48 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user fixes to stbi__cleanup_jpeg path added STBI_ASSERT to avoid requiring assert.h 1.41 (2014-06-25) - fix search&replace from 1.36 that messed up comments/error messages - 1.40 (2014-06-22) - fix gcc struct-initialization warning - 1.39 (2014-06-15) - fix to TGA optimization when req_comp != number of components in TGA; - fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) - add support for BMP version 5 (more ignored fields) - 1.38 (2014-06-06) - suppress MSVC warnings on integer casts truncating values - fix accidental rename of 'skip' field of I/O - 1.37 (2014-06-04) - remove duplicate typedef - 1.36 (2014-06-03) - convert to header file single-file library - if de-iphone isn't set, load iphone images color-swapped instead of returning NULL - 1.35 (2014-05-27) - various warnings - fix broken STBI_SIMD path - fix bug where stbi_load_from_file no longer left file pointer in correct place - fix broken non-easy path for 32-bit BMP (possibly never used) - TGA optimization by Arseny Kapoulkine - 1.34 (unknown) - use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case - 1.33 (2011-07-14) - make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements - 1.32 (2011-07-13) - support for "info" function for all supported filetypes (SpartanJ) - 1.31 (2011-06-20) - a few more leak fixes, bug in PNG handling (SpartanJ) - 1.30 (2011-06-11) - added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + fix search&replace from 1.36 that messed up comments/error + messages 1.40 (2014-06-22) fix gcc struct-initialization warning 1.39 + (2014-06-15) fix to TGA optimization when req_comp != number of components in + TGA; fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my + test suite) add support for BMP version 5 (more ignored fields) 1.38 + (2014-06-06) suppress MSVC warnings on integer casts truncating values fix + accidental rename of 'skip' field of I/O 1.37 (2014-06-04) remove duplicate + typedef 1.36 (2014-06-03) convert to header file single-file library if + de-iphone isn't set, load iphone images color-swapped instead of returning + NULL 1.35 (2014-05-27) various warnings fix broken STBI_SIMD path fix bug + where stbi_load_from_file no longer left file pointer in correct place fix + broken non-easy path for 32-bit BMP (possibly never used) TGA optimization by + Arseny Kapoulkine 1.34 (unknown) use STBI_NOTUSED in + stbi__resample_row_generic(), fix one more leak in tga failure case 1.33 + (2011-07-14) make stbi_is_hdr work in STBI_NO_HDR (as specified), minor + compiler-friendly improvements 1.32 (2011-07-13) support for "info" function + for all supported filetypes (SpartanJ) 1.31 (2011-06-20) a few more leak + fixes, bug in PNG handling (SpartanJ) 1.30 (2011-06-11) added ability to + load files via callbacks to accomidate custom input streams (Ben Wenger) removed deprecated format-specific test/load functions - removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway - error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) - fix inefficiency in decoding 32-bit BMP (David Woo) - 1.29 (2010-08-16) - various warning fixes from Aurelien Pocheville - 1.28 (2010-08-01) - fix bug in GIF palette transparency (SpartanJ) - 1.27 (2010-08-01) - cast-to-stbi_uc to fix warnings - 1.26 (2010-07-24) - fix bug in file buffering for PNG reported by SpartanJ - 1.25 (2010-07-17) - refix trans_data warning (Won Chun) - 1.24 (2010-07-12) - perf improvements reading from files on platforms with lock-heavy fgetc() - minor perf improvements for jpeg - deprecated type-specific functions so we'll get feedback if they're needed - attempt to fix trans_data warning (Won Chun) - 1.23 fixed bug in iPhone support - 1.22 (2010-07-10) - removed image *writing* support - stbi_info support from Jetro Lauha - GIF support from Jean-Marc Lienher + removed support for installable file formats (stbi_loader) -- + would have been broken for IO callbacks anyway error cases in bmp and tga + give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in + decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from + Aurelien Pocheville 1.28 (2010-08-01) fix bug in GIF palette transparency + (SpartanJ) 1.27 (2010-08-01) cast-to-stbi_uc to fix warnings 1.26 + (2010-07-24) fix bug in file buffering for PNG reported by SpartanJ 1.25 + (2010-07-17) refix trans_data warning (Won Chun) 1.24 (2010-07-12) perf + improvements reading from files on platforms with lock-heavy fgetc() minor + perf improvements for jpeg deprecated type-specific functions so we'll get + feedback if they're needed attempt to fix trans_data warning (Won Chun) 1.23 + fixed bug in iPhone support 1.22 (2010-07-10) removed image *writing* + support stbi_info support from Jetro Lauha GIF support from Jean-Marc Lienher iPhone PNG-extensions from James Brown - warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) - 1.21 fix use of 'stbi_uc' in header (reported by jon blow) - 1.20 added support for Softimage PIC, by Tom Seddon - 1.19 bug in interlaced PNG corruption check (found by ryg) - 1.18 (2008-08-02) - fix a threading bug (local mutable static) - 1.17 support interlaced PNG - 1.16 major bugfix - stbi__convert_format converted one too many pixels - 1.15 initialize some fields for thread safety - 1.14 fix threadsafe conversion bug - header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. + Janez (U+017D)emva) 1.21 fix use of 'stbi_uc' in header (reported by jon + blow) 1.20 added support for Softimage PIC, by Tom Seddon 1.19 bug in + interlaced PNG corruption check (found by ryg) 1.18 (2008-08-02) fix a + threading bug (local mutable static) 1.17 support interlaced PNG 1.16 + major bugfix - stbi__convert_format converted one too many pixels 1.15 + initialize some fields for thread safety 1.14 fix threadsafe conversion + bug header-file-only version (#define STBI_HEADER_FILE_ONLY before including) 1.13 threadsafe 1.12 const qualifiers in the API 1.11 Support installable IDCT, colorspace conversion routines @@ -7685,15 +8479,14 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz 1.07 attempt to fix C++ warning/errors again 1.06 attempt to fix C++ warning/errors again - 1.05 fix TGA loading to return correct *comp and use good luminance calc - 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free - 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR - 1.02 support for (subset of) HDR files, float interface for preferred access to them - 1.01 fix bug: possible bug in handling right-side up bmps... not sure - fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all - 1.00 interface to zlib that skips zlib header - 0.99 correct handling of alpha in palette - 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 1.05 fix TGA loading to return correct *comp and use good luminance + calc 1.04 default float alpha is 1, not 255; use 'void *' for + stbi_image_free 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR 1.02 support + for (subset of) HDR files, float interface for preferred access to them 1.01 + fix bug: possible bug in handling right-side up bmps... not sure fix bug: the + stbi__bmp_load() and stbi__tga_load() functions didn't work at all 1.00 + interface to zlib that skips zlib header 0.99 correct handling of alpha in + palette 0.98 TGA loader by lonesock; dynamically add loaders (untested) 0.97 jpeg errors on too large a file; also catch another malloc failure 0.96 fix detection of invalid v value - particleman@mollyrocket forum 0.95 during header scan, seek to markers in case of padding @@ -7706,8 +8499,8 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user 0.60 fix compiling as c++ 0.59 fix warnings: merge Dave Moore's -Wall fixes 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian - 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available - 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but + less than 16 available 0.56 fix bug: zlib uncompressed mode len vs. nlen 0.55 fix bug: restart_interval not initialized to 0 0.54 allow NULL for 'int *comp' 0.53 fix bug in png 3->4; speedup png decoding @@ -7718,7 +8511,6 @@ STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user first released version */ - /* ------------------------------------------------------------------------------ This software is available under 2 licenses -- choose whichever you prefer. diff --git a/sample_programs/ml_sample_programs/vision_models/resnet50-v1-12/compile.sh b/sample_programs/ml_sample_programs/vision_models/resnet50-v1-12/compile.sh index e4491c06..2bac35f5 100755 --- a/sample_programs/ml_sample_programs/vision_models/resnet50-v1-12/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/resnet50-v1-12/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp resnet50-v1-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp resnet50-v1-12.onnx mlir-translate -mlir-to-llvmir resnet50-v1-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/shufflenetv2_10/compile.sh b/sample_programs/ml_sample_programs/vision_models/shufflenetv2_10/compile.sh index 761d4501..6adb6485 100755 --- a/sample_programs/ml_sample_programs/vision_models/shufflenetv2_10/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/shufflenetv2_10/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp shufflenet-v2-10.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp shufflenet-v2-10.onnx mlir-translate -mlir-to-llvmir shufflenet-v2-10.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/squeezenet1.0-12/compile.sh b/sample_programs/ml_sample_programs/vision_models/squeezenet1.0-12/compile.sh index e39e3ede..19538f10 100755 --- a/sample_programs/ml_sample_programs/vision_models/squeezenet1.0-12/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/squeezenet1.0-12/compile.sh @@ -9,7 +9,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp squeezenet1.0-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp squeezenet1.0-12.onnx mlir-translate -mlir-to-llvmir squeezenet1.0-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/vgg16-12/compile.sh b/sample_programs/ml_sample_programs/vision_models/vgg16-12/compile.sh index 61cd3515..28b0f7f7 100755 --- a/sample_programs/ml_sample_programs/vision_models/vgg16-12/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/vgg16-12/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp vgg16-12.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp vgg16-12.onnx mlir-translate -mlir-to-llvmir vgg16-12.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/sample_programs/ml_sample_programs/vision_models/yolo3/compile.sh b/sample_programs/ml_sample_programs/vision_models/yolo3/compile.sh index 702a9e1b..6ca75fe5 100644 --- a/sample_programs/ml_sample_programs/vision_models/yolo3/compile.sh +++ b/sample_programs/ml_sample_programs/vision_models/yolo3/compile.sh @@ -8,7 +8,7 @@ else fi printf "\n[Compile Script]: Convert TF model to LLVM IR\n" -onnx-mlir --EmitLLVMIR --instrument-onnx-ops="ALL" --InstrumentBeforeAndAfterOp yolov3-10.onnx +onnx-mlir --EmitLLVMIR --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentBeforeAndAfterOp yolov3-10.onnx mlir-translate -mlir-to-llvmir yolov3-10.onnx.mlir > model.mlir.ll printf "\n[Compile Script]: Compile main driver program and link to TF model in LLVM IR\n" diff --git a/setup b/setup index e647f079..ad2e5bb3 100755 --- a/setup +++ b/setup @@ -13,14 +13,13 @@ List of options: --help(-h): show help information ---runTests: Add this option if you want to run all regression tests after building LLTFI. +--runTests: Add this option if you want to run all standard regression tests after building LLTFI (equivalent to --all). Note: ML/ONNX tests require additional dependencies and must be run separately with: python3 SCRIPTS/llfi_test --all_ml """ import sys import os import re import getopt -import distutils.core import signal import subprocess import shutil @@ -35,7 +34,6 @@ required_configs = { } optional_configs = { "LLVM_GXX_BIN_DIR": "", - "JAVA_HOME_DIR": "", } run_tests_flag = False @@ -131,7 +129,7 @@ def checkEnvironment(): try: import yaml - except: + except Exception: print(("ERROR: You need to install Python Yaml module " "before building LLFI.\n"), file=sys.stderr) has_error = True @@ -149,22 +147,8 @@ def build(): llvm_paths_cmake = os.path.join(script_path, "config/llvm_paths.cmake") llvm_paths_py = os.path.join(script_path, "config/llvm_paths.py") llvm_paths_make = os.path.join(script_path, "config/llvm_paths.make") - java_paths_cmake = os.path.join(script_path, "config/java_paths.cmake") - java_paths_py = os.path.join(script_path, "config/java_paths.py") - - # build all default software failures - #fidl_path = os.path.join(script_path, 'tools/FIDL/FIDL-Algorithm.py') - #execlist = [fidl_path, '-a', 'default'] - #print('-- Generating default software failures with %s...' % (fidl_path)) - #p = subprocess.call(execlist, stdout = open(os.devnull, 'wb')) - #if p != 0: - # print('-- Failed!') - # sys.exit(p) - #else: - # print('-- Success!') - + # write cmake config file for llvm paths - cmake_File = open(llvm_paths_cmake, "w") LLVM_DST_ROOT = os.path.realpath(required_configs['LLVM_DST_ROOT']) LLVM_SRC_ROOT = os.path.realpath(required_configs['LLVM_SRC_ROOT']) if optional_configs['LLVM_GXX_BIN_DIR'] != "": @@ -172,23 +156,21 @@ def build(): else: LLVM_GXX_BIN_DIR = "" - cmake_File.write("set(LLVM_DST_ROOT " + LLVM_DST_ROOT + ")\n") - cmake_File.write("set(LLVM_SRC_ROOT " + LLVM_SRC_ROOT + ")\n") - cmake_File.close() + with open(llvm_paths_cmake, "w") as cmake_File: + cmake_File.write("set(LLVM_DST_ROOT " + LLVM_DST_ROOT + ")\n") + cmake_File.write("set(LLVM_SRC_ROOT " + LLVM_SRC_ROOT + ")\n") # write python config file for llvm paths - py_File = open(llvm_paths_py, "w") - py_File.write("LLVM_DST_ROOT = " + '"' + LLVM_DST_ROOT + '"\n') - py_File.write("LLVM_SRC_ROOT = " + '"' + LLVM_SRC_ROOT + '"\n') - py_File.write("LLVM_GXX_BIN_DIR = " + '"' + LLVM_GXX_BIN_DIR + '"\n') - py_File.close() + with open(llvm_paths_py, "w") as py_File: + py_File.write("LLVM_DST_ROOT = " + '"' + LLVM_DST_ROOT + '"\n') + py_File.write("LLVM_SRC_ROOT = " + '"' + LLVM_SRC_ROOT + '"\n') + py_File.write("LLVM_GXX_BIN_DIR = " + '"' + LLVM_GXX_BIN_DIR + '"\n') # write unix make config file for llvm paths - make_File = open(llvm_paths_make, 'w') - make_File.write("LLVM_DST_ROOT = " + LLVM_DST_ROOT + '\n') - make_File.write("LLVM_SRC_ROOT = " + LLVM_SRC_ROOT + '\n') - make_File.write("LLVM_GXX_BIN_DIR = " + LLVM_GXX_BIN_DIR + '\n') - make_File.close() + with open(llvm_paths_make, 'w') as make_File: + make_File.write("LLVM_DST_ROOT = " + LLVM_DST_ROOT + '\n') + make_File.write("LLVM_SRC_ROOT = " + LLVM_SRC_ROOT + '\n') + make_File.write("LLVM_GXX_BIN_DIR = " + LLVM_GXX_BIN_DIR + '\n') # build @@ -209,7 +191,7 @@ def build(): def runTests(): - print("Running all regression tests:\n") + print("Running all standard regression tests (use --all_ml separately for ML/ONNX tests):\n") LLFI_BUILD_DIR = os.path.realpath(required_configs['LLFI_BUILD_ROOT']) subprocess.call(["python3", LLFI_BUILD_DIR + "/test_suite/SCRIPTS/llfi_test", "--all", "--threads", "2", "--verbose"]) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..114208aa --- /dev/null +++ b/setup.cfg @@ -0,0 +1,78 @@ +# flake8 configuration for LLTFI +# +# Install: pip install flake8 flake8-bugbear +# Run: flake8 bin/ test_suite/SCRIPTS/ tools/ +# Or: lint.sh (runs flake8 automatically) + +[flake8] +# 120 chars matches the practical line length used in existing scripts. +max-line-length = 120 + +# Files / directories to skip +exclude = + # Build artifacts + build/, + __pycache__/, + *.egg-info/, + # Third-party vendored code + tools/zgrviewer/ + +# Ignored error codes +extend-ignore = + # E203: whitespace before ':' — conflicts with slice notation + E203, + # W503: line break before binary operator — PEP 8 now recommends this style + W503, + # E501: line too long — already handled by max-line-length above + E501, + # ----------------------------------------------------------------------- + # Pre-existing whitespace / formatting debt across all legacy scripts. + # These are suppressed globally because touching every legacy file for + # pure style would produce an enormous, review-unfriendly diff. + # New files must follow CODING_GUIDELINES.md (4-space indent, PEP 8). + # TODO: reformat legacy files incrementally and remove these ignores. + # ----------------------------------------------------------------------- + # Indentation + E111, E114, E117, + # Blank lines around functions / classes + E301, E302, E303, E305, E306, + # Whitespace around operators and after punctuation + E225, E231, E241, E251, + # Whitespace inside brackets / before inline comment + E201, E202, E261, E262, + # Comment formatting (block comments not starting with '# ') + E265, E266, + # Trailing whitespace on lines and blank lines + W291, W293, + # Continuation line indentation + E121, E122, E123, E124, E125, E126, E127, E128, E131, + # Multiple imports on one line — pre-existing in batch scripts + E401, + # Tab indentation and mixed tabs/spaces — pre-existing in batch/test helpers + W191, E101, + # Misc whitespace patterns widespread in legacy scripts + E211, E221, E222, E272, E275, E502, + # Multiple statements on one line — pre-existing + E701, + # Visually indented line with same indent as next logical line + E129 + +# Per-file ignores +per-file-ignores = + # bin/ and test scripts use sys.path.append() before local imports + bin/*.py: E402 + test_suite/SCRIPTS/*.py: E402 + +# flake8-bugbear (B*) checks — enforces several coding guidelines: +# B001: bare except (guidelines: never use bare except:) +# B006: mutable default arguments +# B007: loop variable unused +# B009/B010: use getattr/setattr instead of direct access +# B011: assert False should be raise AssertionError +# B023: function definition in loop +# B904: raise ... from err in except blocks +extend-select = B + +# Specifically require these bugbear rules that map to our guidelines: +# B001 — no bare except: +# B602 — no subprocess(shell=True) [requires flake8-bugbear >= 22.x] diff --git a/test_suite/MakefileGeneration/normal_IR/image.h b/test_suite/MakefileGeneration/normal_IR/image.h index 27fc3e0b..2547dcdb 100644 --- a/test_suite/MakefileGeneration/normal_IR/image.h +++ b/test_suite/MakefileGeneration/normal_IR/image.h @@ -6,8 +6,7 @@ *cr ***************************************************************************/ -struct image_i16 -{ +struct image_i16 { int width; int height; short *data; @@ -17,7 +16,7 @@ struct image_i16 extern "C" { #endif -struct image_i16 * load_image(char *filename); +struct image_i16 *load_image(char *filename); void free_image(struct image_i16 *); #ifdef __cplusplus diff --git a/test_suite/MakefileGeneration/normal_IR/parboil.h b/test_suite/MakefileGeneration/normal_IR/parboil.h index 9885c9dc..f1f16283 100644 --- a/test_suite/MakefileGeneration/normal_IR/parboil.h +++ b/test_suite/MakefileGeneration/normal_IR/parboil.h @@ -10,20 +10,20 @@ extern "C" { /* Command line parameters for benchmarks */ struct pb_Parameters { - char *outFile; /* If not NULL, the raw output of the - * computation should be saved to this - * file. The string is owned. */ - char **inpFiles; /* A NULL-terminated array of strings - * holding the input file(s) for the - * computation. The array and strings - * are owned. */ - int synchronizeGpu; /* Controls behavior of CUDA benchmarks. - * If nonzero, a CUDA runtime - * synchronization call should happen - * after each data transfer to the GPU - * and after each kernel call. This - * is necessary for accurate timing - * measurement. */ + char *outFile; /* If not NULL, the raw output of the + * computation should be saved to this + * file. The string is owned. */ + char **inpFiles; /* A NULL-terminated array of strings + * holding the input file(s) for the + * computation. The array and strings + * are owned. */ + int synchronizeGpu; /* Controls behavior of CUDA benchmarks. + * If nonzero, a CUDA runtime + * synchronization call should happen + * after each data transfer to the GPU + * and after each kernel call. This + * is necessary for accurate timing + * measurement. */ }; /* Read command-line parameters. @@ -35,24 +35,21 @@ struct pb_Parameters { * If there is an error, then an error message is printed on stderr * and NULL is returned. */ -struct pb_Parameters * -pb_ReadParameters(int *_argc, char **argv); +struct pb_Parameters *pb_ReadParameters(int *_argc, char **argv); /* Free an instance of struct pb_Parameters. */ -void -pb_FreeParameters(struct pb_Parameters *p); +void pb_FreeParameters(struct pb_Parameters *p); /* Count the number of input files in a pb_Parameters instance. */ -int -pb_Parameters_CountInputs(struct pb_Parameters *p); +int pb_Parameters_CountInputs(struct pb_Parameters *p); /* A time or duration. */ #if _POSIX_VERSION >= 200112L typedef unsigned long long pb_Timestamp; /* time in microseconds */ #else -# error "Timestamps not implemented" +#error "Timestamps not implemented" #endif enum pb_TimerState { @@ -62,49 +59,45 @@ enum pb_TimerState { struct pb_Timer { enum pb_TimerState state; - pb_Timestamp elapsed; /* Amount of time elapsed so far */ - pb_Timestamp init; /* Beginning of the current time interval, - * if state is RUNNING. Undefined - * otherwise. */ + pb_Timestamp elapsed; /* Amount of time elapsed so far */ + pb_Timestamp init; /* Beginning of the current time interval, + * if state is RUNNING. Undefined + * otherwise. */ }; /* Reset a timer. * Use this to initialize a timer or to clear * its elapsed time. The reset timer is stopped. */ -void -pb_ResetTimer(struct pb_Timer *timer); +void pb_ResetTimer(struct pb_Timer *timer); /* Start a timer. The timer is set to RUNNING mode and * time elapsed while the timer is running is added to * the timer. * The timer should not already be running. */ -void -pb_StartTimer(struct pb_Timer *timer); +void pb_StartTimer(struct pb_Timer *timer); /* Stop a timer. * This stops adding elapsed time to the timer. * The timer should not already be stopped. */ -void -pb_StopTimer(struct pb_Timer *timer); +void pb_StopTimer(struct pb_Timer *timer); /* Get the elapsed time in seconds. */ -double -pb_GetElapsedTime(struct pb_Timer *timer); +double pb_GetElapsedTime(struct pb_Timer *timer); /* Execution time is assigned to one of these categories. */ enum pb_TimerID { pb_TimerID_NONE = 0, - pb_TimerID_IO, /* Time spent in input/output */ - pb_TimerID_GPU, /* Time spent computing on the GPU */ - pb_TimerID_COPY, /* Time spent moving data to/from GPU and - * allocating/freeing memory on the GPU */ - pb_TimerID_COMPUTE, /* Time for all program execution other - * than parsing command line arguments, - * I/O, GPU, and copy */ - pb_TimerID_LAST /* Number of timer IDs */ + pb_TimerID_IO, /* Time spent in input/output */ + pb_TimerID_GPU, /* Time spent computing on the GPU */ + pb_TimerID_COPY, /* Time spent moving data to/from GPU and + * allocating/freeing memory on the GPU */ + pb_TimerID_COMPUTE, /* Time for all program execution other + * than parsing command line arguments, + * I/O, GPU, and copy */ + pb_TimerID_LAST /* Number of timer IDs */ }; /* A set of timers for recording execution times. */ @@ -114,18 +107,15 @@ struct pb_TimerSet { }; /* Reset all timers in the set. */ -void -pb_InitializeTimerSet(struct pb_TimerSet *timers); +void pb_InitializeTimerSet(struct pb_TimerSet *timers); /* Select which timer the next interval of time should be accounted * to. The selected timer is started and other timers are stopped. * Using pb_TimerID_NONE stops all timers. */ -void -pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); +void pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); /* Print timer values to standard output. */ -void -pb_PrintTimerSet(struct pb_TimerSet *timers); +void pb_PrintTimerSet(struct pb_TimerSet *timers); #ifdef __cplusplus } diff --git a/test_suite/MakefileGeneration/normal_IR/sad.h b/test_suite/MakefileGeneration/normal_IR/sad.h index bfd8017f..c8a247ee 100644 --- a/test_suite/MakefileGeneration/normal_IR/sad.h +++ b/test_suite/MakefileGeneration/normal_IR/sad.h @@ -10,7 +10,7 @@ #define SEARCH_RANGE 16 /* The total search area is 33 pixels square */ -#define SEARCH_DIMENSION (2*SEARCH_RANGE+1) +#define SEARCH_DIMENSION (2 * SEARCH_RANGE + 1) /* The total number of search positions is 33^2 */ #define MAX_POS 1089 @@ -18,7 +18,7 @@ /* This is padded to a multiple of 4 when allocating memory */ #define MAX_POS_PADDED 1092 -/* VBSME block indices in the SAD array for different +/* VBSME block indices in the SAD array for different * block sizes. The index is computed from the * image size in macroblocks. Block sizes are (height, width): * 1: 16 by 16 pixels, one block per macroblock @@ -30,22 +30,27 @@ * 7: 4 by 4 pixels, 16 blocks per macroblock */ #define SAD_TYPE_1_IX(image_size) 0 -#define SAD_TYPE_2_IX(image_size) ((image_size)*MAX_POS_PADDED) -#define SAD_TYPE_3_IX(image_size) ((image_size)*(3*MAX_POS_PADDED)) -#define SAD_TYPE_4_IX(image_size) ((image_size)*(5*MAX_POS_PADDED)) -#define SAD_TYPE_5_IX(image_size) ((image_size)*(9*MAX_POS_PADDED)) -#define SAD_TYPE_6_IX(image_size) ((image_size)*(17*MAX_POS_PADDED)) -#define SAD_TYPE_7_IX(image_size) ((image_size)*(25*MAX_POS_PADDED)) +#define SAD_TYPE_2_IX(image_size) ((image_size) * MAX_POS_PADDED) +#define SAD_TYPE_3_IX(image_size) ((image_size) * (3 * MAX_POS_PADDED)) +#define SAD_TYPE_4_IX(image_size) ((image_size) * (5 * MAX_POS_PADDED)) +#define SAD_TYPE_5_IX(image_size) ((image_size) * (9 * MAX_POS_PADDED)) +#define SAD_TYPE_6_IX(image_size) ((image_size) * (17 * MAX_POS_PADDED)) +#define SAD_TYPE_7_IX(image_size) ((image_size) * (25 * MAX_POS_PADDED)) -#define SAD_TYPE_IX(n, image_size) \ - ((n == 1) ? SAD_TYPE_1_IX(image_size) : \ - ((n == 2) ? SAD_TYPE_2_IX(image_size) : \ - ((n == 3) ? SAD_TYPE_3_IX(image_size) : \ - ((n == 4) ? SAD_TYPE_4_IX(image_size) : \ - ((n == 5) ? SAD_TYPE_5_IX(image_size) : \ - ((n == 6) ? SAD_TYPE_6_IX(image_size) : \ - (SAD_TYPE_7_IX(image_size) \ - ))))))) +#define SAD_TYPE_IX(n, image_size) \ + ((n == 1) \ + ? SAD_TYPE_1_IX(image_size) \ + : ((n == 2) \ + ? SAD_TYPE_2_IX(image_size) \ + : ((n == 3) \ + ? SAD_TYPE_3_IX(image_size) \ + : ((n == 4) \ + ? SAD_TYPE_4_IX(image_size) \ + : ((n == 5) \ + ? SAD_TYPE_5_IX(image_size) \ + : ((n == 6) \ + ? SAD_TYPE_6_IX(image_size) \ + : (SAD_TYPE_7_IX(image_size)))))))) #define SAD_TYPE_1_CT 1 #define SAD_TYPE_2_CT 2 @@ -55,28 +60,27 @@ #define SAD_TYPE_6_CT 8 #define SAD_TYPE_7_CT 16 -#define SAD_TYPE_CT(n) \ - ((n == 1) ? SAD_TYPE_1_CT : \ - ((n == 2) ? SAD_TYPE_2_CT : \ - ((n == 3) ? SAD_TYPE_3_CT : \ - ((n == 4) ? SAD_TYPE_4_CT : \ - ((n == 5) ? SAD_TYPE_5_CT : \ - ((n == 6) ? SAD_TYPE_6_CT : \ - (SAD_TYPE_7_CT \ - ))))))) +#define SAD_TYPE_CT(n) \ + ((n == 1) \ + ? SAD_TYPE_1_CT \ + : ((n == 2) \ + ? SAD_TYPE_2_CT \ + : ((n == 3) \ + ? SAD_TYPE_3_CT \ + : ((n == 4) \ + ? SAD_TYPE_4_CT \ + : ((n == 5) ? SAD_TYPE_5_CT \ + : ((n == 6) ? SAD_TYPE_6_CT \ + : (SAD_TYPE_7_CT))))))) #ifdef __cplusplus extern "C" { #endif -void sad4_cpu(unsigned short *blk_sad, - unsigned short *frame, - unsigned short *ref, - int mb_width, - int mb_height); +void sad4_cpu(unsigned short *blk_sad, unsigned short *frame, + unsigned short *ref, int mb_width, int mb_height); -void larger_sads(unsigned short *sads, - int mbs); +void larger_sads(unsigned short *sads, int mbs); #ifdef __cplusplus } diff --git a/test_suite/MakefileGeneration/readable_IR/defines.h b/test_suite/MakefileGeneration/readable_IR/defines.h index 943c5986..2a7f3e57 100644 --- a/test_suite/MakefileGeneration/readable_IR/defines.h +++ b/test_suite/MakefileGeneration/readable_IR/defines.h @@ -1,7 +1,7 @@ /************************************************************************** DEFINES.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,15 +11,13 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Thu Feb 17 21:46:48 2005 by Andreas Loebel (boss.local.de) */ /* $Id: defines.h,v 1.13 2005/02/17 21:43:12 bzfloebe Exp $ */ - - #ifndef _DEFINES_H #define _DEFINES_H @@ -27,104 +25,87 @@ Copyright (c) 2003-2005 Andreas Loebel. #ifndef _WIN32 #include #endif +#include +#include +#include +#include #include #include -#include #include -#include -#include -#include #ifdef INTERNAL_TIMING -#include -#include #include +#include +#include #endif - #include "prototyp.h" - -#define UNBOUNDED 1000000000 -#define ZERO 0 -#define MAX_ART_COST (long)(100000000L) +#define UNBOUNDED 1000000000 +#define ZERO 0 +#define MAX_ART_COST (long)(100000000L) #define ARITHMETIC_TYPE "I" - - -#define FIXED -1 -#define BASIC 0 -#define AT_LOWER 1 -#define AT_UPPER 2 +#define FIXED -1 +#define BASIC 0 +#define AT_LOWER 1 +#define AT_UPPER 2 /* #define AT_ZERO 3 NOT ALLOWED FOR THE SPEC VERSION */ #undef AT_ZERO - -#define UP 1 -#define DOWN 0 - - +#define UP 1 +#define DOWN 0 typedef long flow_t; typedef long cost_t; - - - #ifndef NULL #define NULL 0 #endif - #ifndef ABS -#define ABS( x ) ( ((x) >= 0) ? ( x ) : -( x ) ) +#define ABS(x) (((x) >= 0) ? (x) : -(x)) #endif - #ifndef MAX -#define MAX(a,b) (((a) > (b)) ? (a) : (b)) +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) #endif - #ifndef SET_ZERO -#define SET_ZERO( vec, n ) if( vec ) memset( (void *)vec, 0, (size_t)n ) +#define SET_ZERO(vec, n) \ + if (vec) \ + memset((void *)vec, 0, (size_t)n) #endif - #ifndef FREE -#define FREE( vec ) if( vec ) free( (void *)vec ) +#define FREE(vec) \ + if (vec) \ + free((void *)vec) #endif - typedef struct node node_t; typedef struct node *node_p; typedef struct arc arc_t; typedef struct arc *arc_p; - - -struct node -{ - cost_t potential; +struct node { + cost_t potential; int orientation; node_p child; node_p pred; node_p sibling; - node_p sibling_prev; - arc_p basic_arc; + node_p sibling_prev; + arc_p basic_arc; arc_p firstout, firstin; arc_p arc_tmp; flow_t flow; - long depth; + long depth; int number; int time; }; - - -struct arc -{ +struct arc { cost_t cost; node_p tail, head; int ident; @@ -133,16 +114,13 @@ struct arc cost_t org_cost; }; - - -typedef struct network -{ +typedef struct network { char inputfile[200]; char clustfile[200]; long n, n_trips; long max_m, m, m_org, m_impl; long max_residual_new_m, max_new_m; - + long primal_unbounded; long dual_unbounded; long perturbed; @@ -152,15 +130,14 @@ typedef struct network long feas_tol; long pert_val; long bigM; - double optcost; + double optcost; cost_t ignore_impl; node_p nodes, stop_nodes; arc_p arcs, stop_arcs; - arc_p dummy_arcs, stop_dummy; + arc_p dummy_arcs, stop_dummy; long iterations; long bound_exchanges; long checksum; } network_t; - #endif diff --git a/test_suite/MakefileGeneration/readable_IR/implicit.h b/test_suite/MakefileGeneration/readable_IR/implicit.h index a6b35b49..9478c69b 100644 --- a/test_suite/MakefileGeneration/readable_IR/implicit.h +++ b/test_suite/MakefileGeneration/readable_IR/implicit.h @@ -1,7 +1,7 @@ /************************************************************************** IMPLICIT.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,24 +11,20 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:18 2004 by Andreas Loebel (boss.local.de) */ /* $Id: implicit.h,v 1.11 2005/02/17 19:42:21 bzfloebe Exp $ */ - #ifndef _IMPLICIT_H #define _IMPLICIT_H - -#include "mcfutil.h" #include "mcflimit.h" +#include "mcfutil.h" - -extern long price_out_impl _PROTO_(( network_t * )); -extern long suspend_impl _PROTO_(( network_t *, cost_t, long )); - +extern long price_out_impl _PROTO_((network_t *)); +extern long suspend_impl _PROTO_((network_t *, cost_t, long)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/mcf.h b/test_suite/MakefileGeneration/readable_IR/mcf.h index 01e8841b..526824ea 100644 --- a/test_suite/MakefileGeneration/readable_IR/mcf.h +++ b/test_suite/MakefileGeneration/readable_IR/mcf.h @@ -1,7 +1,7 @@ /************************************************************************** MCF.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,28 +11,24 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:29 2004 by Andreas Loebel (boss.local.de) */ /* $Id: mcf.h,v 1.9 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _MCF_H #define _MCF_H - #include "defines.h" -#include "mcfutil.h" -#include "readmin.h" -#include "output.h" -#include "pstart.h" -#include "psimplex.h" -#include "pbeampp.h" #include "implicit.h" #include "limits.h" - +#include "mcfutil.h" +#include "output.h" +#include "pbeampp.h" +#include "psimplex.h" +#include "pstart.h" +#include "readmin.h" #endif diff --git a/test_suite/MakefileGeneration/readable_IR/mcflimit.h b/test_suite/MakefileGeneration/readable_IR/mcflimit.h index 9a831c55..ade762b4 100644 --- a/test_suite/MakefileGeneration/readable_IR/mcflimit.h +++ b/test_suite/MakefileGeneration/readable_IR/mcflimit.h @@ -1,7 +1,7 @@ /************************************************************************** MCFLIMIT.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,18 +11,16 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Thu Feb 17 22:24:36 2005 by Andreas Loebel (boss.local.de) */ /* $Id: mcflimit.h,v 1.12 2005/02/17 21:43:12 bzfloebe Exp $ */ - #ifndef _MCF_LIMITS_H #define _MCF_LIMITS_H - #define BIGM 1.0e7 #define STRECHT(x) ((long)(1.25 * (double)(x))) @@ -31,9 +29,8 @@ Copyright (c) 2003-2005 Andreas Loebel. #define MAX_NEW_ARCS_SMALL_NET 3000000 #define MAX_NEW_ARCS_LARGE_NET 28900000 -#define MAX_NB_ITERATIONS_SMALL_NET 5 -#define MAX_NB_ITERATIONS_LARGE_NET 5 - +#define MAX_NB_ITERATIONS_SMALL_NET 5 +#define MAX_NB_ITERATIONS_LARGE_NET 5 /* // Some operating systems and compiler, respectively, do not handle reallocs @@ -42,5 +39,4 @@ Copyright (c) 2003-2005 Andreas Loebel. */ #define SPEC_STATIC - #endif diff --git a/test_suite/MakefileGeneration/readable_IR/mcfutil.h b/test_suite/MakefileGeneration/readable_IR/mcfutil.h index 5ae24855..678525e4 100644 --- a/test_suite/MakefileGeneration/readable_IR/mcfutil.h +++ b/test_suite/MakefileGeneration/readable_IR/mcfutil.h @@ -1,7 +1,7 @@ /************************************************************************** MCFUTIL.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,29 +11,24 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:49 2004 by Andreas Loebel (boss.local.de) */ /* $Id: mcfutil.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _MCFUTIL_H #define _MCFUTIL_H - #include "defines.h" - -extern void refresh_neighbour_lists _PROTO_(( network_t * )); -extern long refresh_potential _PROTO_(( network_t * )); -extern double flow_cost _PROTO_(( network_t * )); -extern double flow_org_cost _PROTO_(( network_t * )); -extern long primal_feasible _PROTO_(( network_t * )); -extern long dual_feasible _PROTO_(( network_t * )); -extern long getfree _PROTO_(( network_t * )); - +extern void refresh_neighbour_lists _PROTO_((network_t *)); +extern long refresh_potential _PROTO_((network_t *)); +extern double flow_cost _PROTO_((network_t *)); +extern double flow_org_cost _PROTO_((network_t *)); +extern long primal_feasible _PROTO_((network_t *)); +extern long dual_feasible _PROTO_((network_t *)); +extern long getfree _PROTO_((network_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/output.h b/test_suite/MakefileGeneration/readable_IR/output.h index 323ea485..a0b51286 100644 --- a/test_suite/MakefileGeneration/readable_IR/output.h +++ b/test_suite/MakefileGeneration/readable_IR/output.h @@ -1,7 +1,7 @@ /************************************************************************** OUTPUT.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:59 2004 by Andreas Loebel (boss.local.de) */ /* $Id: output.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _OUTPUT_H #define _OUTPUT_H - #include "mcfutil.h" - -extern long write_circulations _PROTO_(( char *, network_t * )); - +extern long write_circulations _PROTO_((char *, network_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/pbeampp.h b/test_suite/MakefileGeneration/readable_IR/pbeampp.h index 834acd2d..03957b13 100644 --- a/test_suite/MakefileGeneration/readable_IR/pbeampp.h +++ b/test_suite/MakefileGeneration/readable_IR/pbeampp.h @@ -1,7 +1,7 @@ /************************************************************************** PBEAMPP.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:09 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pbeampp.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PBEAMPP_H #define _PBEAMPP_H - #include "defines.h" - -extern arc_t *primal_bea_mpp _PROTO_(( long, arc_t*, arc_t*, cost_t* )); - +extern arc_t *primal_bea_mpp _PROTO_((long, arc_t *, arc_t *, cost_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/pbla.h b/test_suite/MakefileGeneration/readable_IR/pbla.h index 70fdbfd8..f062452b 100644 --- a/test_suite/MakefileGeneration/readable_IR/pbla.h +++ b/test_suite/MakefileGeneration/readable_IR/pbla.h @@ -1,7 +1,7 @@ /************************************************************************** PBLA.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,24 +11,19 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:19 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pbla.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PBLA_H #define _PBLA_H - #include "defines.h" - -extern node_t *primal_iminus _PROTO_(( flow_t *, long *, node_t *, - node_t *, node_t ** )); - +extern node_t *primal_iminus _PROTO_((flow_t *, long *, node_t *, node_t *, + node_t **)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/pflowup.h b/test_suite/MakefileGeneration/readable_IR/pflowup.h index 9f19c3cf..6209509f 100644 --- a/test_suite/MakefileGeneration/readable_IR/pflowup.h +++ b/test_suite/MakefileGeneration/readable_IR/pflowup.h @@ -1,7 +1,7 @@ /************************************************************************** PFLOWUP.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:28 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pflowup.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PFLOWUP_H #define _PFLOWUP_H - #include "defines.h" - -extern void primal_update_flow _PROTO_(( node_t *, node_t *, node_t * )); - +extern void primal_update_flow _PROTO_((node_t *, node_t *, node_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/prototyp.h b/test_suite/MakefileGeneration/readable_IR/prototyp.h index d791d85e..eccd4173 100644 --- a/test_suite/MakefileGeneration/readable_IR/prototyp.h +++ b/test_suite/MakefileGeneration/readable_IR/prototyp.h @@ -1,7 +1,7 @@ /************************************************************************** PROTOTYPE.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,34 +11,23 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Wed Feb 16 20:24:10 2005 by Andreas Loebel (boss.local.de) */ /* $Id: prototyp.h,v 1.11 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PROTOTYP_H #define _PROTOTYP_H - #ifndef _PROTO_ -#if defined(__STDC__) || defined(__cplusplus) \ - || defined(WANT_STDC_PROTO) || defined(SPEC_CPU) -#define _PROTO_( args ) args +#if defined(__STDC__) || defined(__cplusplus) || defined(WANT_STDC_PROTO) || \ + defined(SPEC_CPU) +#define _PROTO_(args) args #else -#define _PROTO_( args ) +#define _PROTO_(args) #endif #endif - #endif - - - - - - - diff --git a/test_suite/MakefileGeneration/readable_IR/psimplex.h b/test_suite/MakefileGeneration/readable_IR/psimplex.h index 9562489b..27e983ae 100644 --- a/test_suite/MakefileGeneration/readable_IR/psimplex.h +++ b/test_suite/MakefileGeneration/readable_IR/psimplex.h @@ -1,7 +1,7 @@ /************************************************************************** PSIMPLEX.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,28 +11,23 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:48 2004 by Andreas Loebel (boss.local.de) */ /* $Id: psimplex.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PSIMPLEX_H #define _PSIMPLEX_H - #include "defines.h" +#include "mcfutil.h" #include "pbeampp.h" #include "pbla.h" #include "pflowup.h" #include "treeup.h" -#include "mcfutil.h" - - -extern long primal_net_simplex _PROTO_(( network_t * )); +extern long primal_net_simplex _PROTO_((network_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/pstart.h b/test_suite/MakefileGeneration/readable_IR/pstart.h index becb4430..b0bc60b3 100644 --- a/test_suite/MakefileGeneration/readable_IR/pstart.h +++ b/test_suite/MakefileGeneration/readable_IR/pstart.h @@ -1,7 +1,7 @@ /************************************************************************** PSTART.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:58 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pstart.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PSTART_H #define _PSTART_H - #include "defines.h" - -extern long primal_start_artificial _PROTO_(( network_t * )); - +extern long primal_start_artificial _PROTO_((network_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/readmin.h b/test_suite/MakefileGeneration/readable_IR/readmin.h index 602efbcd..b7f9a661 100644 --- a/test_suite/MakefileGeneration/readable_IR/readmin.h +++ b/test_suite/MakefileGeneration/readable_IR/readmin.h @@ -1,7 +1,7 @@ /************************************************************************** READMIN.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,25 +11,20 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:23:07 2004 by Andreas Loebel (boss.local.de) */ /* $Id: readmin.h,v 1.11 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _READMIN_H #define _READMIN_H - #include "defines.h" -#include "mcfutil.h" #include "mcflimit.h" +#include "mcfutil.h" - -extern long read_min _PROTO_(( network_t * )); - +extern long read_min _PROTO_((network_t *)); #endif diff --git a/test_suite/MakefileGeneration/readable_IR/treeup.h b/test_suite/MakefileGeneration/readable_IR/treeup.h index c7fed7b3..0ca13ee0 100644 --- a/test_suite/MakefileGeneration/readable_IR/treeup.h +++ b/test_suite/MakefileGeneration/readable_IR/treeup.h @@ -1,7 +1,7 @@ /************************************************************************** TREEUP.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,25 +11,20 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:23:16 2004 by Andreas Loebel (boss.local.de) */ /* $Id: treeup.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _TREEUP_H #define _TREEUP_H - #include "defines.h" - -extern void update_tree _PROTO_(( long, long, flow_t, flow_t, node_t *, - node_t *, node_t *, node_t *, node_t *, - arc_t *, cost_t, flow_t )); - +extern void update_tree _PROTO_((long, long, flow_t, flow_t, node_t *, node_t *, + node_t *, node_t *, node_t *, arc_t *, cost_t, + flow_t)); #endif diff --git a/test_suite/PROGRAMS/bfs/main.cpp b/test_suite/PROGRAMS/bfs/main.cpp old mode 100755 new mode 100644 index e16a3441..4c17878f --- a/test_suite/PROGRAMS/bfs/main.cpp +++ b/test_suite/PROGRAMS/bfs/main.cpp @@ -1,212 +1,206 @@ -/*************************************************************************** - *cr - *cr (C) Copyright 2007 The Board of Trustees of the - *cr University of Illinois - *cr All Rights Reserved - *cr - ***************************************************************************/ -/* - Implementing Breadth first search on CUDA using algorithm given in DAC'10 - paper "An Effective GPU Implementation of Breadth-First Search" - - Copyright (c) 2010 University of Illinois at Urbana-Champaign. - All rights reserved. - - Permission to use, copy, modify and distribute this software and its documentation for - educational purpose is hereby granted without fee, provided that the above copyright - notice and this permission notice appear in all copies of this software and that you do - not sell the software. - - THE SOFTWARE IS PROVIDED "AS IS" AND WITHOUT WARRANTY OF ANY KIND,EXPRESS, IMPLIED OR - OTHERWISE. - - Author: Lijiuan Luo (lluo3@uiuc.edu) -*/ -#include -#include -#include -#include -#include "parboil.h" -#include -#include - -#define MAX_THREADS_PER_BLOCK 512 -#define NUM_SM 30//the number of Streaming Multiprocessors; may change in the future archs -#define NUM_SP 16//8//the number of Streaming processors within each SM; may change in the future - //architectures -#define EXP 4//3// EXP = log(NUM_SP), assuming NUM_SP is still power of 2 in the future architecture - //using EXP and shifting can speed up division operation -#define MOD_OP 8//7 // This variable is also related with NUM_SP; may change in the future architecture; - //using MOD_OP and "bitwise and" can speed up mod operation -#define INF 2147483647//2^31-1 - -#define UP_LIMIT 16677216//2^24 -#define WHITE 16677217 -#define GRAY 16677218 -#define GRAY0 16677219 -#define GRAY1 16677220 -#define BLACK 16677221 -int no_of_nodes; //the number of nodes in the graph -int edge_list_size;//the number of edges in the graph -FILE *fp; - -//typedef int2 Node; -//typedef int2 Edge; - -struct Node{ - int x; - int y; -}; - -struct Edge{ - int x; - int y; -}; -//Somehow "cudaMemset" does not work. So I use cudaMemcpy of constant variables for initialization -const int h_top = 1; -const int zero = 0; - -void runCPU(int argc, char** argv); -void runGPU(int argc, char** argv); -//////////////////////////////////////////////////////////////////// -//the cpu version of bfs for speed comparison -//the text book version ("Introduction to Algorithms") -//////////////////////////////////////////////////////////////////// -void BFS_CPU( Node * h_graph_nodes,Edge * h_graph_edges, - int * color, int * h_cost, int source){ - std::deque wavefront; - wavefront.push_back(source); - color[source] = GRAY; - int index; - while(!wavefront.empty()){ - index = wavefront.front(); - wavefront.pop_front(); - for(int i=h_graph_nodes[index].x; - i<(h_graph_nodes[index].y + - h_graph_nodes[index].x); i++) - { - int id = h_graph_edges[i].x; - if(color[id] == WHITE){ - h_cost[id]=h_cost[index]+1; - wavefront.push_back(id); - color[id] = GRAY; - } - } - color[index] = BLACK; - - - } - -} -//////////////////////////////////////////////////////////////////////////////// -// Main Program -//////////////////////////////////////////////////////////////////////////////// -int main( int argc, char** argv) -{ - no_of_nodes=0; - edge_list_size=0; - runCPU(argc,argv); -// if( cutCheckCmdLineFlag(argc, (const char**)argv, "device") ) -// cutilDeviceInit(argc, argv); -// else - //cudaSetDevice( cutGetMaxGflopsDeviceId() ); -// cudaSetDevice( 1); - - - //CUT_EXIT(argc, argv); -} -/////////////////////////////// -//FUNCTION: only run CPU version -//////////////////////////////////////////// -void runCPU( int argc, char** argv) -{ - - struct pb_Parameters *params; - struct pb_TimerSet timers; - - pb_InitializeTimerSet(&timers); - params = pb_ReadParameters(&argc, argv); - if ((params->inpFiles[0] == NULL) || (params->inpFiles[1] != NULL)) - { - fprintf(stderr, "Expecting one input filename\n"); - exit(-1); - } - - pb_SwitchToTimer(&timers, pb_TimerID_IO); - //printf("Reading File\n"); - //Read in Graph from a file - fp = fopen(params->inpFiles[0],"r"); - if(!fp) - { - printf("Error Reading graph file\n"); - return; - } - - int source; - - fscanf(fp,"%d",&no_of_nodes); - // allocate host memory - Node* h_graph_nodes = (Node*) malloc(sizeof(Node)*no_of_nodes); - int *color = (int*) malloc(sizeof(int)*no_of_nodes); - int start, edgeno; - // initalize the memory - for( unsigned int i = 0; i < no_of_nodes; i++) - { - fscanf(fp,"%d %d",&start,&edgeno); - h_graph_nodes[i].x = start; - h_graph_nodes[i].y = edgeno; - color[i]=WHITE; - } - //read the source node from the file - fscanf(fp,"%d",&source); - fscanf(fp,"%d",&edge_list_size); - int id,cost; - Edge* h_graph_edges = (Edge*) malloc(sizeof(Edge)*edge_list_size); - for(int i=0; i < edge_list_size ; i++) - { - fscanf(fp,"%d",&id); - fscanf(fp,"%d",&cost); - h_graph_edges[i].x = id; - h_graph_edges[i].y = cost; - } - if(fp) - fclose(fp); - - //printf("Read File\n"); - - // allocate mem for the result on host side - int* h_cost = (int*) malloc( sizeof(int)*no_of_nodes); - for(int i = 0; i < no_of_nodes; i++){ - h_cost[i] = INF; - } - h_cost[source] = 0; - //printf("start cpu version\n"); - unsigned int cpu_timer = 0; - pb_SwitchToTimer(&timers, pb_TimerID_COMPUTE); - BFS_CPU( h_graph_nodes, h_graph_edges, color, h_cost, source - ); - pb_SwitchToTimer(&timers, pb_TimerID_IO); - if(params->outFile!=NULL) - { - //printf("Result stored in %s\n", params->outFile); - FILE *fp = fopen(params->outFile,"w"); - fprintf(fp,"%d\n", no_of_nodes); - for(int i=0;i +#include +#include +#include +#include +#include + +#include "parboil.h" + +#define MAX_THREADS_PER_BLOCK 512 +#define NUM_SM \ + 30 // the number of Streaming Multiprocessors; may change in the future archs +#define NUM_SP \ + 16 // 8//the number of Streaming processors within each SM; may change in the + // future +// architectures +#define EXP \ + 4 // 3// EXP = log(NUM_SP), assuming NUM_SP is still power of 2 in the future + // architecture +// using EXP and shifting can speed up division operation +#define MOD_OP \ + 8 // 7 // This variable is also related with NUM_SP; may change in the future + // architecture; +// using MOD_OP and "bitwise and" can speed up mod operation +#define INF 2147483647 // 2^31-1 + +#define UP_LIMIT 16677216 // 2^24 +#define WHITE 16677217 +#define GRAY 16677218 +#define GRAY0 16677219 +#define GRAY1 16677220 +#define BLACK 16677221 +int no_of_nodes; // the number of nodes in the graph +int edge_list_size; // the number of edges in the graph +FILE *fp; + +// typedef int2 Node; +// typedef int2 Edge; + +struct Node { + int x; + int y; +}; + +struct Edge { + int x; + int y; +}; +// Somehow "cudaMemset" does not work. So I use cudaMemcpy of constant variables +// for initialization +const int h_top = 1; +const int zero = 0; + +void runCPU(int argc, char **argv); +void runGPU(int argc, char **argv); +//////////////////////////////////////////////////////////////////// +// the cpu version of bfs for speed comparison +// the text book version ("Introduction to Algorithms") +//////////////////////////////////////////////////////////////////// +void BFS_CPU(Node *h_graph_nodes, Edge *h_graph_edges, int *color, int *h_cost, + int source) { + std::deque wavefront; + wavefront.push_back(source); + color[source] = GRAY; + int index; + while (!wavefront.empty()) { + index = wavefront.front(); + wavefront.pop_front(); + for (int i = h_graph_nodes[index].x; + i < (h_graph_nodes[index].y + h_graph_nodes[index].x); i++) { + int id = h_graph_edges[i].x; + if (color[id] == WHITE) { + h_cost[id] = h_cost[index] + 1; + wavefront.push_back(id); + color[id] = GRAY; + } + } + color[index] = BLACK; + } +} +//////////////////////////////////////////////////////////////////////////////// +// Main Program +//////////////////////////////////////////////////////////////////////////////// +int main(int argc, char **argv) { + no_of_nodes = 0; + edge_list_size = 0; + runCPU(argc, argv); + // if( cutCheckCmdLineFlag(argc, (const char**)argv, "device") ) + // cutilDeviceInit(argc, argv); + // else + // cudaSetDevice( cutGetMaxGflopsDeviceId() ); + // cudaSetDevice( 1); + + // CUT_EXIT(argc, argv); +} +/////////////////////////////// +// FUNCTION: only run CPU version +//////////////////////////////////////////// +void runCPU(int argc, char **argv) { + + struct pb_Parameters *params; + struct pb_TimerSet timers; + + pb_InitializeTimerSet(&timers); + params = pb_ReadParameters(&argc, argv); + if ((params->inpFiles[0] == NULL) || (params->inpFiles[1] != NULL)) { + fprintf(stderr, "Expecting one input filename\n"); + exit(-1); + } + + pb_SwitchToTimer(&timers, pb_TimerID_IO); + // printf("Reading File\n"); + // Read in Graph from a file + fp = fopen(params->inpFiles[0], "r"); + if (!fp) { + printf("Error Reading graph file\n"); + return; + } + + int source; + + fscanf(fp, "%d", &no_of_nodes); + // allocate host memory + Node *h_graph_nodes = (Node *)malloc(sizeof(Node) * no_of_nodes); + int *color = (int *)malloc(sizeof(int) * no_of_nodes); + int start, edgeno; + // initalize the memory + for (unsigned int i = 0; i < no_of_nodes; i++) { + fscanf(fp, "%d %d", &start, &edgeno); + h_graph_nodes[i].x = start; + h_graph_nodes[i].y = edgeno; + color[i] = WHITE; + } + // read the source node from the file + fscanf(fp, "%d", &source); + fscanf(fp, "%d", &edge_list_size); + int id, cost; + Edge *h_graph_edges = (Edge *)malloc(sizeof(Edge) * edge_list_size); + for (int i = 0; i < edge_list_size; i++) { + fscanf(fp, "%d", &id); + fscanf(fp, "%d", &cost); + h_graph_edges[i].x = id; + h_graph_edges[i].y = cost; + } + if (fp) + fclose(fp); + + // printf("Read File\n"); + + // allocate mem for the result on host side + int *h_cost = (int *)malloc(sizeof(int) * no_of_nodes); + for (int i = 0; i < no_of_nodes; i++) { + h_cost[i] = INF; + } + h_cost[source] = 0; + // printf("start cpu version\n"); + unsigned int cpu_timer = 0; + pb_SwitchToTimer(&timers, pb_TimerID_COMPUTE); + BFS_CPU(h_graph_nodes, h_graph_edges, color, h_cost, source); + pb_SwitchToTimer(&timers, pb_TimerID_IO); + if (params->outFile != NULL) { + // printf("Result stored in %s\n", params->outFile); + FILE *fp = fopen(params->outFile, "w"); + fprintf(fp, "%d\n", no_of_nodes); + for (int i = 0; i < no_of_nodes; i++) + fprintf(fp, "%d %d\n", i, h_cost[i]); + fclose(fp); + } + + pb_SwitchToTimer(&timers, pb_TimerID_COMPUTE); + // cleanup memory + free(h_graph_nodes); + free(h_graph_edges); + free(color); + free(h_cost); + pb_SwitchToTimer(&timers, pb_TimerID_NONE); + pb_PrintTimerSet(&timers); + pb_FreeParameters(params); +} +/////////////////////////////// +// FUNCTION:only run GPU version +//////////////////////////////////////////// diff --git a/test_suite/PROGRAMS/bfs/parboil.cpp b/test_suite/PROGRAMS/bfs/parboil.cpp old mode 100755 new mode 100644 index 0884419a..5946852a --- a/test_suite/PROGRAMS/bfs/parboil.cpp +++ b/test_suite/PROGRAMS/bfs/parboil.cpp @@ -3,38 +3,39 @@ */ #include "parboil.h" + +#include #include #include -#include #if _POSIX_VERSION >= 200112L -# include +#include #endif /* Free an array of owned strings. */ -static void -free_string_array(char **string_array) -{ +static void free_string_array(char **string_array) { char **p; - if (!string_array) return; - for (p = string_array; *p; p++) free(*p); + if (!string_array) + return; + for (p = string_array; *p; p++) + free(*p); free(string_array); } /* Parse a comma-delimited list of strings into an * array of strings. */ -static char ** -read_string_array(char *in) -{ +static char **read_string_array(char *in) { char **ret; int i; - int count; /* Number of items in the input */ - char *substring; /* Current substring within 'in' */ + int count; /* Number of items in the input */ + char *substring; /* Current substring within 'in' */ /* Count the number of items in the string */ count = 1; - for (i = 0; in[i]; i++) if (in[i] == ',') count++; + for (i = 0; in[i]; i++) + if (in[i] == ',') + count++; /* Allocate storage */ ret = (char **)malloc((count + 1) * sizeof(char *)); @@ -47,8 +48,8 @@ read_string_array(char *in) /* Find length of substring */ for (substring_end = substring; - (*substring_end != ',') && (*substring_end != 0); - substring_end++); + (*substring_end != ',') && (*substring_end != 0); substring_end++) + ; substring_length = substring_end - substring; @@ -60,41 +61,35 @@ read_string_array(char *in) /* go to next substring */ substring = substring_end + 1; } - ret[i] = NULL; /* Write the sentinel value */ + ret[i] = NULL; /* Write the sentinel value */ return ret; } struct argparse { - int argc; /* Number of arguments. Mutable. */ - char **argv; /* Argument values. Immutable. */ + int argc; /* Number of arguments. Mutable. */ + char **argv; /* Argument values. Immutable. */ - int argn; /* Current argument number. */ - char **argv_get; /* Argument value being read. */ - char **argv_put; /* Argument value being written. - * argv_put <= argv_get. */ + int argn; /* Current argument number. */ + char **argv_get; /* Argument value being read. */ + char **argv_put; /* Argument value being written. + * argv_put <= argv_get. */ }; -static void -initialize_argparse(struct argparse *ap, int argc, char **argv) -{ +static void initialize_argparse(struct argparse *ap, int argc, char **argv) { ap->argc = argc; ap->argn = 0; ap->argv_get = ap->argv_put = ap->argv = argv; } -static void -finalize_argparse(struct argparse *ap) -{ +static void finalize_argparse(struct argparse *ap) { /* Move the remaining arguments */ - for(; ap->argn < ap->argc; ap->argn++) + for (; ap->argn < ap->argc; ap->argn++) *ap->argv_put++ = *ap->argv_get++; } /* Delete the current argument. */ -static void -delete_argument(struct argparse *ap) -{ +static void delete_argument(struct argparse *ap) { if (ap->argn >= ap->argc) { fprintf(stderr, "delete_argument\n"); } @@ -104,9 +99,7 @@ delete_argument(struct argparse *ap) /* Go to the next argument. Also, move the current argument to its * final location in argv. */ -static void -next_argument(struct argparse *ap) -{ +static void next_argument(struct argparse *ap) { if (ap->argn >= ap->argc) { fprintf(stderr, "next_argument\n"); } @@ -115,33 +108,25 @@ next_argument(struct argparse *ap) ap->argn++; } -static int -is_end_of_arguments(struct argparse *ap) -{ +static int is_end_of_arguments(struct argparse *ap) { return ap->argn == ap->argc; } -static char * -get_argument(struct argparse *ap) -{ +static char *get_argument(struct argparse *ap) { return *ap->argv_get; } -static char * -consume_argument(struct argparse *ap) -{ +static char *consume_argument(struct argparse *ap) { char *ret = get_argument(ap); delete_argument(ap); return ret; } -struct pb_Parameters * -pb_ReadParameters(int *_argc, char **argv) -{ +struct pb_Parameters *pb_ReadParameters(int *_argc, char **argv) { char *err_message; struct argparse ap; struct pb_Parameters *ret = - (struct pb_Parameters *)malloc(sizeof(struct pb_Parameters)); + (struct pb_Parameters *)malloc(sizeof(struct pb_Parameters)); /* Initialize the parameters structure */ ret->outFile = NULL; @@ -150,59 +135,54 @@ pb_ReadParameters(int *_argc, char **argv) /* Each argument */ initialize_argparse(&ap, *_argc, argv); - while(!is_end_of_arguments(&ap)) { + while (!is_end_of_arguments(&ap)) { char *arg = get_argument(&ap); /* Single-character flag */ if ((arg[0] == '-') && (arg[1] != 0) && (arg[2] == 0)) { - delete_argument(&ap); /* This argument is consumed here */ - - switch(arg[1]) { - case 'o': /* Output file name */ - if (is_end_of_arguments(&ap)) - { - err_message = "Expecting file name after '-o'\n"; - goto error; - } - free(ret->outFile); - ret->outFile = strdup(consume_argument(&ap)); - break; - case 'i': /* Input file name */ - if (is_end_of_arguments(&ap)) - { - err_message = "Expecting file name after '-i'\n"; - goto error; - } - ret->inpFiles = read_string_array(consume_argument(&ap)); - break; - case '-': /* End of options */ - goto end_of_options; + delete_argument(&ap); /* This argument is consumed here */ + + switch (arg[1]) { + case 'o': /* Output file name */ + if (is_end_of_arguments(&ap)) { + err_message = "Expecting file name after '-o'\n"; + goto error; + } + free(ret->outFile); + ret->outFile = strdup(consume_argument(&ap)); + break; + case 'i': /* Input file name */ + if (is_end_of_arguments(&ap)) { + err_message = "Expecting file name after '-i'\n"; + goto error; + } + ret->inpFiles = read_string_array(consume_argument(&ap)); + break; + case '-': /* End of options */ + goto end_of_options; default: - err_message = "Unexpected command-line parameter\n"; - goto error; + err_message = "Unexpected command-line parameter\n"; + goto error; } - } - else { + } else { /* Other parameters are ignored */ next_argument(&ap); } } /* end for each argument */ - end_of_options: - *_argc = ap.argc; /* Save the modified argc value */ +end_of_options: + *_argc = ap.argc; /* Save the modified argc value */ finalize_argparse(&ap); return ret; - error: +error: fputs(err_message, stderr); pb_FreeParameters(ret); return NULL; } -void -pb_FreeParameters(struct pb_Parameters *p) -{ +void pb_FreeParameters(struct pb_Parameters *p) { char **cpp; free(p->outFile); @@ -210,56 +190,47 @@ pb_FreeParameters(struct pb_Parameters *p) free(p); } -int -pb_Parameters_CountInputs(struct pb_Parameters *p) -{ +int pb_Parameters_CountInputs(struct pb_Parameters *p) { int n; - for (n = 0; p->inpFiles[n]; n++); + for (n = 0; p->inpFiles[n]; n++) + ; return n; } /*****************************************************************************/ /* Timer routines */ -static void -accumulate_time(pb_Timestamp *accum, - pb_Timestamp start, - pb_Timestamp end) -{ +static void accumulate_time(pb_Timestamp *accum, pb_Timestamp start, + pb_Timestamp end) { #if _POSIX_VERSION >= 200112L *accum += end - start; #else -# error "Timestamps not implemented for this system" +#error "Timestamps not implemented for this system" #endif } #if _POSIX_VERSION >= 200112L -static pb_Timestamp get_time() -{ +static pb_Timestamp get_time() { struct timeval tv; gettimeofday(&tv, NULL); - return (pb_Timestamp) (tv.tv_sec * 1000000LL + tv.tv_usec); + return (pb_Timestamp)(tv.tv_sec * 1000000LL + tv.tv_usec); } #else -# error "no supported time libraries are available on this platform" +#error "no supported time libraries are available on this platform" #endif -void -pb_ResetTimer(struct pb_Timer *timer) -{ +void pb_ResetTimer(struct pb_Timer *timer) { timer->state = pb_Timer_STOPPED; #if _POSIX_VERSION >= 200112L timer->elapsed = 0; #else -# error "pb_ResetTimer: not implemented for this system" +#error "pb_ResetTimer: not implemented for this system" #endif } -void -pb_StartTimer(struct pb_Timer *timer) -{ +void pb_StartTimer(struct pb_Timer *timer) { if (timer->state != pb_Timer_STOPPED) { fputs("Ignoring attempt to start a running timer\n", stderr); return; @@ -274,13 +245,12 @@ pb_StartTimer(struct pb_Timer *timer) timer->init = tv.tv_sec * 1000000LL + tv.tv_usec; } #else -# error "pb_StartTimer: not implemented for this system" +#error "pb_StartTimer: not implemented for this system" #endif } -void -pb_StartTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) -{ +void pb_StartTimerAndSubTimer(struct pb_Timer *timer, + struct pb_Timer *subtimer) { unsigned int numNotStopped = 0x3; // 11 if (timer->state != pb_Timer_STOPPED) { fputs("Warning: Timer was not stopped\n", stderr); @@ -302,24 +272,21 @@ pb_StartTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) { struct timeval tv; gettimeofday(&tv, NULL); - + if (numNotStopped & 0x2) { timer->init = tv.tv_sec * 1000000LL + tv.tv_usec; } - + if (numNotStopped & 0x1) { subtimer->init = tv.tv_sec * 1000000LL + tv.tv_usec; } } #else -# error "pb_StartTimer: not implemented for this system" +#error "pb_StartTimer: not implemented for this system" #endif - } -void -pb_StopTimer(struct pb_Timer *timer) -{ +void pb_StopTimer(struct pb_Timer *timer) { pb_Timestamp fini; @@ -337,15 +304,15 @@ pb_StopTimer(struct pb_Timer *timer) fini = tv.tv_sec * 1000000LL + tv.tv_usec; } #else -# error "pb_StopTimer: not implemented for this system" +#error "pb_StopTimer: not implemented for this system" #endif accumulate_time(&timer->elapsed, timer->init, fini); timer->init = fini; - } -void pb_StopTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) { +void pb_StopTimerAndSubTimer(struct pb_Timer *timer, + struct pb_Timer *subtimer) { pb_Timestamp fini; @@ -363,7 +330,6 @@ void pb_StopTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) return; } - timer->state = pb_Timer_STOPPED; subtimer->state = pb_Timer_STOPPED; @@ -374,25 +340,22 @@ void pb_StopTimerAndSubTimer(struct pb_Timer *timer, struct pb_Timer *subtimer) fini = tv.tv_sec * 1000000LL + tv.tv_usec; } #else -# error "pb_StopTimer: not implemented for this system" +#error "pb_StopTimer: not implemented for this system" #endif if (numNotRunning & 0x2) { accumulate_time(&timer->elapsed, timer->init, fini); timer->init = fini; } - + if (numNotRunning & 0x1) { accumulate_time(&subtimer->elapsed, subtimer->init, fini); subtimer->init = fini; } - } /* Get the elapsed time in seconds. */ -double -pb_GetElapsedTime(struct pb_Timer *timer) -{ +double pb_GetElapsedTime(struct pb_Timer *timer) { double ret; if (timer->state != pb_Timer_STOPPED) { @@ -402,22 +365,19 @@ pb_GetElapsedTime(struct pb_Timer *timer) #if _POSIX_VERSION >= 200112L ret = timer->elapsed / 1e6; #else -# error "pb_GetElapsedTime: not implemented for this system" +#error "pb_GetElapsedTime: not implemented for this system" #endif return ret; } -void -pb_InitializeTimerSet(struct pb_TimerSet *timers) -{ +void pb_InitializeTimerSet(struct pb_TimerSet *timers) { int n; - + timers->wall_begin = get_time(); timers->current = pb_TimerID_NONE; timers->async_markers = NULL; - for (n = 0; n < pb_TimerID_LAST; n++) { pb_ResetTimer(&timers->timers[n]); @@ -425,24 +385,24 @@ pb_InitializeTimerSet(struct pb_TimerSet *timers) } } -void -pb_AddSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID pb_Category) { - - struct pb_SubTimer *subtimer = (struct pb_SubTimer *) malloc - (sizeof(struct pb_SubTimer)); - +void pb_AddSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID pb_Category) { + + struct pb_SubTimer *subtimer = + (struct pb_SubTimer *)malloc(sizeof(struct pb_SubTimer)); + int len = strlen(label); - - subtimer->label = (char *) malloc (sizeof(char)*(len+1)); + + subtimer->label = (char *)malloc(sizeof(char) * (len + 1)); sprintf(subtimer->label, "%s\0", label); - + pb_ResetTimer(&subtimer->timer); subtimer->next = NULL; - + struct pb_SubTimerList *subtimerlist = timers->sub_timer_list[pb_Category]; if (subtimerlist == NULL) { - subtimerlist = (struct pb_SubTimerList *) malloc - (sizeof(struct pb_SubTimerList)); + subtimerlist = + (struct pb_SubTimerList *)malloc(sizeof(struct pb_SubTimerList)); subtimerlist->subtimer_list = subtimer; timers->sub_timer_list[pb_Category] = subtimerlist; } else { @@ -453,28 +413,30 @@ pb_AddSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID pb_Categ } element->next = subtimer; } - } -void -pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID category) -{ +void pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID category) { + + // switchToSub( NULL, NONE + // switchToSub( NULL, some + // switchToSub( some, some + // switchToSub( some, NONE -- tries to find "some" in NONE's sublist, which + // won't be printed -// switchToSub( NULL, NONE -// switchToSub( NULL, some -// switchToSub( some, some -// switchToSub( some, NONE -- tries to find "some" in NONE's sublist, which won't be printed - struct pb_Timer *topLevelToStop = NULL; if (timers->current != category && timers->current != pb_TimerID_NONE) { - // Switching to subtimer in a different category needs to stop the top-level current, different categoried timer. - // NONE shouldn't have a timer associated with it, so exclude from branch + // Switching to subtimer in a different category needs to stop the top-level + // current, different categoried timer. NONE shouldn't have a timer + // associated with it, so exclude from branch topLevelToStop = &timers->timers[timers->current]; - } + } + + struct pb_SubTimerList *subtimerlist = + timers->sub_timer_list[timers->current]; + struct pb_SubTimer *curr = + (subtimerlist == NULL) ? NULL : subtimerlist->current; - struct pb_SubTimerList *subtimerlist = timers->sub_timer_list[timers->current]; - struct pb_SubTimer *curr = (subtimerlist == NULL) ? NULL : subtimerlist->current; - if (timers->current != pb_TimerID_NONE) { if (curr != NULL && topLevelToStop != NULL) { pb_StopTimerAndSubTimer(topLevelToStop, &curr->timer); @@ -484,11 +446,11 @@ pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID cat pb_StopTimer(topLevelToStop); } } - + subtimerlist = timers->sub_timer_list[category]; struct pb_SubTimer *subtimer = NULL; - - if (label != NULL) { + + if (label != NULL) { subtimer = subtimerlist->subtimer_list; while (subtimer != NULL) { if (strcmp(subtimer->label, label) == 0) { @@ -497,46 +459,45 @@ pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID cat subtimer = subtimer->next; } } - } - + } + if (category != pb_TimerID_NONE) { - + if (subtimerlist != NULL) { subtimerlist->current = subtimer; } - + if (category != timers->current && subtimer != NULL) { pb_StartTimerAndSubTimer(&timers->timers[category], &subtimer->timer); } else if (subtimer != NULL) { // Same category, different non-NULL subtimer pb_StartTimer(&subtimer->timer); - } else{ - // Different category, but no subtimer (not found or specified as NULL) -- unprefered way of setting topLevel timer + } else { + // Different category, but no subtimer (not found or specified as NULL) -- + // unprefered way of setting topLevel timer pb_StartTimer(&timers->timers[category]); } - } - + } + timers->current = category; - } -void -pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer) -{ +void pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer) { /* Stop the currently running timer */ if (timers->current != pb_TimerID_NONE) { struct pb_SubTimer *currSubTimer = NULL; - struct pb_SubTimerList *subtimerlist = timers->sub_timer_list[timers->current]; - - if ( subtimerlist != NULL) { + struct pb_SubTimerList *subtimerlist = + timers->sub_timer_list[timers->current]; + + if (subtimerlist != NULL) { currSubTimer = timers->sub_timer_list[timers->current]->current; } - if ( currSubTimer!= NULL) { - pb_StopTimerAndSubTimer(&timers->timers[timers->current], &currSubTimer->timer); + if (currSubTimer != NULL) { + pb_StopTimerAndSubTimer(&timers->timers[timers->current], + &currSubTimer->timer); } else { pb_StopTimer(&timers->timers[timers->current]); } - } timers->current = timer; @@ -546,30 +507,29 @@ pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer) } } -void -pb_PrintTimerSet(struct pb_TimerSet *timers) -{ +void pb_PrintTimerSet(struct pb_TimerSet *timers) { pb_Timestamp wall_end = get_time(); struct pb_Timer *t = timers->timers; - struct pb_SubTimer* sub = NULL; - + struct pb_SubTimer *sub = NULL; + int maxSubLength; - - const char *categories[] = { - "IO", "Kernel", "Copy", "Driver", "Copy Async", "Compute" - }; - + + const char *categories[] = {"IO", "Kernel", "Copy", + "Driver", "Copy Async", "Compute"}; + const int maxCategoryLength = 10; - + int i; - for(i = 1; i < pb_TimerID_LAST-1; ++i) { // exclude NONE and OVRELAP from this format - if(pb_GetElapsedTime(&t[i]) != 0) { - + for (i = 1; i < pb_TimerID_LAST - 1; + ++i) { // exclude NONE and OVRELAP from this format + if (pb_GetElapsedTime(&t[i]) != 0) { + // Print Category Timer - printf("%-*s: %f\n", maxCategoryLength, categories[i-1], pb_GetElapsedTime(&t[i])); - + printf("%-*s: %f\n", maxCategoryLength, categories[i - 1], + pb_GetElapsedTime(&t[i])); + if (timers->sub_timer_list[i] != NULL) { sub = timers->sub_timer_list[i]->subtimer_list; maxSubLength = 0; @@ -580,44 +540,44 @@ pb_PrintTimerSet(struct pb_TimerSet *timers) } sub = sub->next; } - + // Fit to Categories if (maxSubLength <= maxCategoryLength) { - maxSubLength = maxCategoryLength; + maxSubLength = maxCategoryLength; } - + sub = timers->sub_timer_list[i]->subtimer_list; - + // Print SubTimers while (sub != NULL) { - printf(" -%-*s: %f\n", maxSubLength, sub->label, pb_GetElapsedTime(&sub->timer)); + printf(" -%-*s: %f\n", maxSubLength, sub->label, + pb_GetElapsedTime(&sub->timer)); sub = sub->next; } } } } - - if(pb_GetElapsedTime(&t[pb_TimerID_OVERLAP]) != 0) - printf("CPU/Kernel Overlap: %f\n", pb_GetElapsedTime(&t[pb_TimerID_OVERLAP])); - - float walltime = (wall_end - timers->wall_begin)/ 1e6; - printf("Timer Wall Time: %f\n", walltime); - + + if (pb_GetElapsedTime(&t[pb_TimerID_OVERLAP]) != 0) + printf("CPU/Kernel Overlap: %f\n", + pb_GetElapsedTime(&t[pb_TimerID_OVERLAP])); + + float walltime = (wall_end - timers->wall_begin) / 1e6; + printf("Timer Wall Time: %f\n", walltime); } -void pb_DestroyTimerSet(struct pb_TimerSet * timers) -{ +void pb_DestroyTimerSet(struct pb_TimerSet *timers) { /* clean up all of the async event markers */ - struct pb_async_time_marker_list ** event = &(timers->async_markers); - while( *event != NULL) { - struct pb_async_time_marker_list ** next = &((*event)->next); + struct pb_async_time_marker_list **event = &(timers->async_markers); + while (*event != NULL) { + struct pb_async_time_marker_list **next = &((*event)->next); free(*event); (*event) = NULL; event = next; } - + int i = 0; - for(i = 0; i < pb_TimerID_LAST; ++i) { + for (i = 0; i < pb_TimerID_LAST; ++i) { if (timers->sub_timer_list[i] != NULL) { struct pb_SubTimer *subtimer = timers->sub_timer_list[i]->subtimer_list; struct pb_SubTimer *prev = NULL; @@ -631,5 +591,3 @@ void pb_DestroyTimerSet(struct pb_TimerSet * timers) } } } - - diff --git a/test_suite/PROGRAMS/bfs/parboil.h b/test_suite/PROGRAMS/bfs/parboil.h old mode 100755 new mode 100644 index bc4891f2..f2222c67 --- a/test_suite/PROGRAMS/bfs/parboil.h +++ b/test_suite/PROGRAMS/bfs/parboil.h @@ -12,13 +12,13 @@ extern "C" { /* Command line parameters for benchmarks */ struct pb_Parameters { - char *outFile; /* If not NULL, the raw output of the - * computation should be saved to this - * file. The string is owned. */ - char **inpFiles; /* A NULL-terminated array of strings - * holding the input file(s) for the - * computation. The array and strings - * are owned. */ + char *outFile; /* If not NULL, the raw output of the + * computation should be saved to this + * file. The string is owned. */ + char **inpFiles; /* A NULL-terminated array of strings + * holding the input file(s) for the + * computation. The array and strings + * are owned. */ }; /* Read command-line parameters. @@ -30,24 +30,21 @@ struct pb_Parameters { * If there is an error, then an error message is printed on stderr * and NULL is returned. */ -struct pb_Parameters * -pb_ReadParameters(int *_argc, char **argv); +struct pb_Parameters *pb_ReadParameters(int *_argc, char **argv); /* Free an instance of struct pb_Parameters. */ -void -pb_FreeParameters(struct pb_Parameters *p); +void pb_FreeParameters(struct pb_Parameters *p); /* Count the number of input files in a pb_Parameters instance. */ -int -pb_Parameters_CountInputs(struct pb_Parameters *p); +int pb_Parameters_CountInputs(struct pb_Parameters *p); /* A time or duration. */ #if _POSIX_VERSION >= 200112L typedef unsigned long long pb_Timestamp; /* time in microseconds */ #else -# error "Timestamps not implemented" +#error "Timestamps not implemented" #endif enum pb_TimerState { @@ -57,68 +54,64 @@ enum pb_TimerState { struct pb_Timer { enum pb_TimerState state; - pb_Timestamp elapsed; /* Amount of time elapsed so far */ - pb_Timestamp init; /* Beginning of the current time interval, - * if state is RUNNING. End of the last - * recorded time interfal otherwise. */ + pb_Timestamp elapsed; /* Amount of time elapsed so far */ + pb_Timestamp init; /* Beginning of the current time interval, + * if state is RUNNING. End of the last + * recorded time interfal otherwise. */ }; /* Reset a timer. * Use this to initialize a timer or to clear * its elapsed time. The reset timer is stopped. */ -void -pb_ResetTimer(struct pb_Timer *timer); +void pb_ResetTimer(struct pb_Timer *timer); /* Start a timer. The timer is set to RUNNING mode and * time elapsed while the timer is running is added to * the timer. * The timer should not already be running. */ -void -pb_StartTimer(struct pb_Timer *timer); +void pb_StartTimer(struct pb_Timer *timer); /* Stop a timer. * This stops adding elapsed time to the timer. * The timer should not already be stopped. */ -void -pb_StopTimer(struct pb_Timer *timer); +void pb_StopTimer(struct pb_Timer *timer); /* Get the elapsed time in seconds. */ -double -pb_GetElapsedTime(struct pb_Timer *timer); +double pb_GetElapsedTime(struct pb_Timer *timer); /* Execution time is assigned to one of these categories. */ enum pb_TimerID { pb_TimerID_NONE = 0, - pb_TimerID_IO, /* Time spent in input/output */ - pb_TimerID_KERNEL, /* Time spent computing on the device, - * recorded asynchronously */ - pb_TimerID_COPY, /* Time spent synchronously moving data - * to/from device and allocating/freeing - * memory on the device */ - pb_TimerID_DRIVER, /* Time spent in the host interacting with the - * driver, primarily for recording the time - * spent queueing asynchronous operations */ - pb_TimerID_COPY_ASYNC, /* Time spent in asynchronous transfers */ - pb_TimerID_COMPUTE, /* Time for all program execution other - * than parsing command line arguments, - * I/O, kernel, and copy */ - pb_TimerID_OVERLAP, /* Time double-counted in asynchronous and - * host activity: automatically filled in, - * not intended for direct usage */ - pb_TimerID_LAST /* Number of timer IDs */ + pb_TimerID_IO, /* Time spent in input/output */ + pb_TimerID_KERNEL, /* Time spent computing on the device, + * recorded asynchronously */ + pb_TimerID_COPY, /* Time spent synchronously moving data + * to/from device and allocating/freeing + * memory on the device */ + pb_TimerID_DRIVER, /* Time spent in the host interacting with the + * driver, primarily for recording the time + * spent queueing asynchronous operations */ + pb_TimerID_COPY_ASYNC, /* Time spent in asynchronous transfers */ + pb_TimerID_COMPUTE, /* Time for all program execution other + * than parsing command line arguments, + * I/O, kernel, and copy */ + pb_TimerID_OVERLAP, /* Time double-counted in asynchronous and + * host activity: automatically filled in, + * not intended for direct usage */ + pb_TimerID_LAST /* Number of timer IDs */ }; /* Dynamic list of asynchronously tracked times between events */ struct pb_async_time_marker_list { - char *label; // actually just a pointer to a string - enum pb_TimerID timerID; /* The ID to which the interval beginning - * with this marker should be attributed */ - void * marker; - //cudaEvent_t marker; /* The driver event for this marker */ - struct pb_async_time_marker_list *next; + char *label; // actually just a pointer to a string + enum pb_TimerID timerID; /* The ID to which the interval beginning + * with this marker should be attributed */ + void *marker; + // cudaEvent_t marker; /* The driver event for this marker */ + struct pb_async_time_marker_list *next; }; struct pb_SubTimer { @@ -135,7 +128,7 @@ struct pb_SubTimerList { /* A set of timers for recording execution times. */ struct pb_TimerSet { enum pb_TimerID current; - struct pb_async_time_marker_list* async_markers; + struct pb_async_time_marker_list *async_markers; pb_Timestamp async_begin; pb_Timestamp wall_begin; struct pb_Timer timers[pb_TimerID_LAST]; @@ -143,34 +136,29 @@ struct pb_TimerSet { }; /* Reset all timers in the set. */ -void -pb_InitializeTimerSet(struct pb_TimerSet *timers); +void pb_InitializeTimerSet(struct pb_TimerSet *timers); -void -pb_AddSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID pb_Category); +void pb_AddSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID pb_Category); /* Select which timer the next interval of time should be accounted * to. The selected timer is started and other timers are stopped. * Using pb_TimerID_NONE stops all timers. */ -void -pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); +void pb_SwitchToTimer(struct pb_TimerSet *timers, enum pb_TimerID timer); -void -pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, enum pb_TimerID category); +void pb_SwitchToSubTimer(struct pb_TimerSet *timers, char *label, + enum pb_TimerID category); /* Print timer values to standard output. */ -void -pb_PrintTimerSet(struct pb_TimerSet *timers); +void pb_PrintTimerSet(struct pb_TimerSet *timers); /* Release timer resources */ -void -pb_DestroyTimerSet(struct pb_TimerSet * timers); +void pb_DestroyTimerSet(struct pb_TimerSet *timers); -void -pb_SetOpenCL(void *clContextPtr, void *clCommandQueuePtr); +void pb_SetOpenCL(void *clContextPtr, void *clCommandQueuePtr); #ifdef __cplusplus } #endif -#endif //PARBOIL_HEADER +#endif // PARBOIL_HEADER diff --git a/test_suite/PROGRAMS/factorial/factorial.c b/test_suite/PROGRAMS/factorial/factorial.c index 14b39385..40793ea6 100644 --- a/test_suite/PROGRAMS/factorial/factorial.c +++ b/test_suite/PROGRAMS/factorial/factorial.c @@ -1,15 +1,14 @@ #include #include -main(argc, argv) -int argc; -char *argv[]; + +int main(int argc, char *argv[]) { -int i,fact, n; -n = atoi(argv[1]); -fact = 1; -for(i=1;i<=n;i++) -{ -fact = fact * i; -} -printf("%d\n",fact); + int i, fact, n; + n = atoi(argv[1]); + fact = 1; + for (i = 1; i <= n; i++) { + fact = fact * i; + } + printf("%d\n", fact); + return 0; } diff --git a/test_suite/PROGRAMS/mcf/defines.h b/test_suite/PROGRAMS/mcf/defines.h index 943c5986..2a7f3e57 100644 --- a/test_suite/PROGRAMS/mcf/defines.h +++ b/test_suite/PROGRAMS/mcf/defines.h @@ -1,7 +1,7 @@ /************************************************************************** DEFINES.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,15 +11,13 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Thu Feb 17 21:46:48 2005 by Andreas Loebel (boss.local.de) */ /* $Id: defines.h,v 1.13 2005/02/17 21:43:12 bzfloebe Exp $ */ - - #ifndef _DEFINES_H #define _DEFINES_H @@ -27,104 +25,87 @@ Copyright (c) 2003-2005 Andreas Loebel. #ifndef _WIN32 #include #endif +#include +#include +#include +#include #include #include -#include #include -#include -#include -#include #ifdef INTERNAL_TIMING -#include -#include #include +#include +#include #endif - #include "prototyp.h" - -#define UNBOUNDED 1000000000 -#define ZERO 0 -#define MAX_ART_COST (long)(100000000L) +#define UNBOUNDED 1000000000 +#define ZERO 0 +#define MAX_ART_COST (long)(100000000L) #define ARITHMETIC_TYPE "I" - - -#define FIXED -1 -#define BASIC 0 -#define AT_LOWER 1 -#define AT_UPPER 2 +#define FIXED -1 +#define BASIC 0 +#define AT_LOWER 1 +#define AT_UPPER 2 /* #define AT_ZERO 3 NOT ALLOWED FOR THE SPEC VERSION */ #undef AT_ZERO - -#define UP 1 -#define DOWN 0 - - +#define UP 1 +#define DOWN 0 typedef long flow_t; typedef long cost_t; - - - #ifndef NULL #define NULL 0 #endif - #ifndef ABS -#define ABS( x ) ( ((x) >= 0) ? ( x ) : -( x ) ) +#define ABS(x) (((x) >= 0) ? (x) : -(x)) #endif - #ifndef MAX -#define MAX(a,b) (((a) > (b)) ? (a) : (b)) +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) #endif - #ifndef SET_ZERO -#define SET_ZERO( vec, n ) if( vec ) memset( (void *)vec, 0, (size_t)n ) +#define SET_ZERO(vec, n) \ + if (vec) \ + memset((void *)vec, 0, (size_t)n) #endif - #ifndef FREE -#define FREE( vec ) if( vec ) free( (void *)vec ) +#define FREE(vec) \ + if (vec) \ + free((void *)vec) #endif - typedef struct node node_t; typedef struct node *node_p; typedef struct arc arc_t; typedef struct arc *arc_p; - - -struct node -{ - cost_t potential; +struct node { + cost_t potential; int orientation; node_p child; node_p pred; node_p sibling; - node_p sibling_prev; - arc_p basic_arc; + node_p sibling_prev; + arc_p basic_arc; arc_p firstout, firstin; arc_p arc_tmp; flow_t flow; - long depth; + long depth; int number; int time; }; - - -struct arc -{ +struct arc { cost_t cost; node_p tail, head; int ident; @@ -133,16 +114,13 @@ struct arc cost_t org_cost; }; - - -typedef struct network -{ +typedef struct network { char inputfile[200]; char clustfile[200]; long n, n_trips; long max_m, m, m_org, m_impl; long max_residual_new_m, max_new_m; - + long primal_unbounded; long dual_unbounded; long perturbed; @@ -152,15 +130,14 @@ typedef struct network long feas_tol; long pert_val; long bigM; - double optcost; + double optcost; cost_t ignore_impl; node_p nodes, stop_nodes; arc_p arcs, stop_arcs; - arc_p dummy_arcs, stop_dummy; + arc_p dummy_arcs, stop_dummy; long iterations; long bound_exchanges; long checksum; } network_t; - #endif diff --git a/test_suite/PROGRAMS/mcf/implicit.h b/test_suite/PROGRAMS/mcf/implicit.h index a6b35b49..9478c69b 100644 --- a/test_suite/PROGRAMS/mcf/implicit.h +++ b/test_suite/PROGRAMS/mcf/implicit.h @@ -1,7 +1,7 @@ /************************************************************************** IMPLICIT.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,24 +11,20 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:18 2004 by Andreas Loebel (boss.local.de) */ /* $Id: implicit.h,v 1.11 2005/02/17 19:42:21 bzfloebe Exp $ */ - #ifndef _IMPLICIT_H #define _IMPLICIT_H - -#include "mcfutil.h" #include "mcflimit.h" +#include "mcfutil.h" - -extern long price_out_impl _PROTO_(( network_t * )); -extern long suspend_impl _PROTO_(( network_t *, cost_t, long )); - +extern long price_out_impl _PROTO_((network_t *)); +extern long suspend_impl _PROTO_((network_t *, cost_t, long)); #endif diff --git a/test_suite/PROGRAMS/mcf/mcf.h b/test_suite/PROGRAMS/mcf/mcf.h index 01e8841b..526824ea 100644 --- a/test_suite/PROGRAMS/mcf/mcf.h +++ b/test_suite/PROGRAMS/mcf/mcf.h @@ -1,7 +1,7 @@ /************************************************************************** MCF.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,28 +11,24 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:29 2004 by Andreas Loebel (boss.local.de) */ /* $Id: mcf.h,v 1.9 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _MCF_H #define _MCF_H - #include "defines.h" -#include "mcfutil.h" -#include "readmin.h" -#include "output.h" -#include "pstart.h" -#include "psimplex.h" -#include "pbeampp.h" #include "implicit.h" #include "limits.h" - +#include "mcfutil.h" +#include "output.h" +#include "pbeampp.h" +#include "psimplex.h" +#include "pstart.h" +#include "readmin.h" #endif diff --git a/test_suite/PROGRAMS/mcf/mcflimit.h b/test_suite/PROGRAMS/mcf/mcflimit.h index 9a831c55..ade762b4 100644 --- a/test_suite/PROGRAMS/mcf/mcflimit.h +++ b/test_suite/PROGRAMS/mcf/mcflimit.h @@ -1,7 +1,7 @@ /************************************************************************** MCFLIMIT.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,18 +11,16 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Thu Feb 17 22:24:36 2005 by Andreas Loebel (boss.local.de) */ /* $Id: mcflimit.h,v 1.12 2005/02/17 21:43:12 bzfloebe Exp $ */ - #ifndef _MCF_LIMITS_H #define _MCF_LIMITS_H - #define BIGM 1.0e7 #define STRECHT(x) ((long)(1.25 * (double)(x))) @@ -31,9 +29,8 @@ Copyright (c) 2003-2005 Andreas Loebel. #define MAX_NEW_ARCS_SMALL_NET 3000000 #define MAX_NEW_ARCS_LARGE_NET 28900000 -#define MAX_NB_ITERATIONS_SMALL_NET 5 -#define MAX_NB_ITERATIONS_LARGE_NET 5 - +#define MAX_NB_ITERATIONS_SMALL_NET 5 +#define MAX_NB_ITERATIONS_LARGE_NET 5 /* // Some operating systems and compiler, respectively, do not handle reallocs @@ -42,5 +39,4 @@ Copyright (c) 2003-2005 Andreas Loebel. */ #define SPEC_STATIC - #endif diff --git a/test_suite/PROGRAMS/mcf/mcfutil.h b/test_suite/PROGRAMS/mcf/mcfutil.h index 5ae24855..678525e4 100644 --- a/test_suite/PROGRAMS/mcf/mcfutil.h +++ b/test_suite/PROGRAMS/mcf/mcfutil.h @@ -1,7 +1,7 @@ /************************************************************************** MCFUTIL.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,29 +11,24 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:49 2004 by Andreas Loebel (boss.local.de) */ /* $Id: mcfutil.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _MCFUTIL_H #define _MCFUTIL_H - #include "defines.h" - -extern void refresh_neighbour_lists _PROTO_(( network_t * )); -extern long refresh_potential _PROTO_(( network_t * )); -extern double flow_cost _PROTO_(( network_t * )); -extern double flow_org_cost _PROTO_(( network_t * )); -extern long primal_feasible _PROTO_(( network_t * )); -extern long dual_feasible _PROTO_(( network_t * )); -extern long getfree _PROTO_(( network_t * )); - +extern void refresh_neighbour_lists _PROTO_((network_t *)); +extern long refresh_potential _PROTO_((network_t *)); +extern double flow_cost _PROTO_((network_t *)); +extern double flow_org_cost _PROTO_((network_t *)); +extern long primal_feasible _PROTO_((network_t *)); +extern long dual_feasible _PROTO_((network_t *)); +extern long getfree _PROTO_((network_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/output.h b/test_suite/PROGRAMS/mcf/output.h index 323ea485..a0b51286 100644 --- a/test_suite/PROGRAMS/mcf/output.h +++ b/test_suite/PROGRAMS/mcf/output.h @@ -1,7 +1,7 @@ /************************************************************************** OUTPUT.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:21:59 2004 by Andreas Loebel (boss.local.de) */ /* $Id: output.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _OUTPUT_H #define _OUTPUT_H - #include "mcfutil.h" - -extern long write_circulations _PROTO_(( char *, network_t * )); - +extern long write_circulations _PROTO_((char *, network_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/pbeampp.h b/test_suite/PROGRAMS/mcf/pbeampp.h index 834acd2d..03957b13 100644 --- a/test_suite/PROGRAMS/mcf/pbeampp.h +++ b/test_suite/PROGRAMS/mcf/pbeampp.h @@ -1,7 +1,7 @@ /************************************************************************** PBEAMPP.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:09 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pbeampp.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PBEAMPP_H #define _PBEAMPP_H - #include "defines.h" - -extern arc_t *primal_bea_mpp _PROTO_(( long, arc_t*, arc_t*, cost_t* )); - +extern arc_t *primal_bea_mpp _PROTO_((long, arc_t *, arc_t *, cost_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/pbla.h b/test_suite/PROGRAMS/mcf/pbla.h index 70fdbfd8..f062452b 100644 --- a/test_suite/PROGRAMS/mcf/pbla.h +++ b/test_suite/PROGRAMS/mcf/pbla.h @@ -1,7 +1,7 @@ /************************************************************************** PBLA.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,24 +11,19 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:19 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pbla.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PBLA_H #define _PBLA_H - #include "defines.h" - -extern node_t *primal_iminus _PROTO_(( flow_t *, long *, node_t *, - node_t *, node_t ** )); - +extern node_t *primal_iminus _PROTO_((flow_t *, long *, node_t *, node_t *, + node_t **)); #endif diff --git a/test_suite/PROGRAMS/mcf/pflowup.h b/test_suite/PROGRAMS/mcf/pflowup.h index 9f19c3cf..6209509f 100644 --- a/test_suite/PROGRAMS/mcf/pflowup.h +++ b/test_suite/PROGRAMS/mcf/pflowup.h @@ -1,7 +1,7 @@ /************************************************************************** PFLOWUP.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:28 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pflowup.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PFLOWUP_H #define _PFLOWUP_H - #include "defines.h" - -extern void primal_update_flow _PROTO_(( node_t *, node_t *, node_t * )); - +extern void primal_update_flow _PROTO_((node_t *, node_t *, node_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/prototyp.h b/test_suite/PROGRAMS/mcf/prototyp.h index d791d85e..eccd4173 100644 --- a/test_suite/PROGRAMS/mcf/prototyp.h +++ b/test_suite/PROGRAMS/mcf/prototyp.h @@ -1,7 +1,7 @@ /************************************************************************** PROTOTYPE.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,34 +11,23 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Wed Feb 16 20:24:10 2005 by Andreas Loebel (boss.local.de) */ /* $Id: prototyp.h,v 1.11 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PROTOTYP_H #define _PROTOTYP_H - #ifndef _PROTO_ -#if defined(__STDC__) || defined(__cplusplus) \ - || defined(WANT_STDC_PROTO) || defined(SPEC_CPU) -#define _PROTO_( args ) args +#if defined(__STDC__) || defined(__cplusplus) || defined(WANT_STDC_PROTO) || \ + defined(SPEC_CPU) +#define _PROTO_(args) args #else -#define _PROTO_( args ) +#define _PROTO_(args) #endif #endif - #endif - - - - - - - diff --git a/test_suite/PROGRAMS/mcf/psimplex.h b/test_suite/PROGRAMS/mcf/psimplex.h index 9562489b..27e983ae 100644 --- a/test_suite/PROGRAMS/mcf/psimplex.h +++ b/test_suite/PROGRAMS/mcf/psimplex.h @@ -1,7 +1,7 @@ /************************************************************************** PSIMPLEX.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,28 +11,23 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:48 2004 by Andreas Loebel (boss.local.de) */ /* $Id: psimplex.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PSIMPLEX_H #define _PSIMPLEX_H - #include "defines.h" +#include "mcfutil.h" #include "pbeampp.h" #include "pbla.h" #include "pflowup.h" #include "treeup.h" -#include "mcfutil.h" - - -extern long primal_net_simplex _PROTO_(( network_t * )); +extern long primal_net_simplex _PROTO_((network_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/pstart.h b/test_suite/PROGRAMS/mcf/pstart.h index becb4430..b0bc60b3 100644 --- a/test_suite/PROGRAMS/mcf/pstart.h +++ b/test_suite/PROGRAMS/mcf/pstart.h @@ -1,7 +1,7 @@ /************************************************************************** PSTART.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,23 +11,18 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:22:58 2004 by Andreas Loebel (boss.local.de) */ /* $Id: pstart.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _PSTART_H #define _PSTART_H - #include "defines.h" - -extern long primal_start_artificial _PROTO_(( network_t * )); - +extern long primal_start_artificial _PROTO_((network_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/readmin.h b/test_suite/PROGRAMS/mcf/readmin.h index 602efbcd..b7f9a661 100644 --- a/test_suite/PROGRAMS/mcf/readmin.h +++ b/test_suite/PROGRAMS/mcf/readmin.h @@ -1,7 +1,7 @@ /************************************************************************** READMIN.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,25 +11,20 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:23:07 2004 by Andreas Loebel (boss.local.de) */ /* $Id: readmin.h,v 1.11 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _READMIN_H #define _READMIN_H - #include "defines.h" -#include "mcfutil.h" #include "mcflimit.h" +#include "mcfutil.h" - -extern long read_min _PROTO_(( network_t * )); - +extern long read_min _PROTO_((network_t *)); #endif diff --git a/test_suite/PROGRAMS/mcf/treeup.h b/test_suite/PROGRAMS/mcf/treeup.h index c7fed7b3..0ca13ee0 100644 --- a/test_suite/PROGRAMS/mcf/treeup.h +++ b/test_suite/PROGRAMS/mcf/treeup.h @@ -1,7 +1,7 @@ /************************************************************************** TREEUP.H of ZIB optimizer MCF, SPEC version -This software was developed at ZIB Berlin. Maintenance and revisions +This software was developed at ZIB Berlin. Maintenance and revisions solely on responsibility of Andreas Loebel Dr. Andreas Loebel @@ -11,25 +11,20 @@ Konrad-Zuse-Zentrum fuer Informationstechnik Berlin (ZIB) Scientific Computing - Optimization Takustr. 7, 14195 Berlin-Dahlem -Copyright (c) 1998-2000 ZIB. -Copyright (c) 2000-2002 ZIB & Loebel. +Copyright (c) 1998-2000 ZIB. +Copyright (c) 2000-2002 ZIB & Loebel. Copyright (c) 2003-2005 Andreas Loebel. **************************************************************************/ /* LAST EDIT: Sun Nov 21 16:23:16 2004 by Andreas Loebel (boss.local.de) */ /* $Id: treeup.h,v 1.10 2005/02/17 19:42:21 bzfloebe Exp $ */ - - #ifndef _TREEUP_H #define _TREEUP_H - #include "defines.h" - -extern void update_tree _PROTO_(( long, long, flow_t, flow_t, node_t *, - node_t *, node_t *, node_t *, node_t *, - arc_t *, cost_t, flow_t )); - +extern void update_tree _PROTO_((long, long, flow_t, flow_t, node_t *, node_t *, + node_t *, node_t *, node_t *, arc_t *, cost_t, + flow_t)); #endif diff --git a/test_suite/SCRIPTS/CMakeLists.txt b/test_suite/SCRIPTS/CMakeLists.txt index ae8f87ed..3a135866 100644 --- a/test_suite/SCRIPTS/CMakeLists.txt +++ b/test_suite/SCRIPTS/CMakeLists.txt @@ -14,5 +14,8 @@ copy(inject_prog.py inject_prog.py) copy(test_trace_tools.py test_trace_tools.py) copy(llfi_test.py llfi_test) copy(test_generate_makefile.py test_generate_makefile.py) +copy(test_ml_tools.py test_ml_tools.py) +copy(test_ml_models.py test_ml_models.py) +copy(test_instruction_duplication.py test_instruction_duplication.py) genCopy() diff --git a/test_suite/SCRIPTS/build_prog.py b/test_suite/SCRIPTS/build_prog.py index 911f5f47..79a75174 100644 --- a/test_suite/SCRIPTS/build_prog.py +++ b/test_suite/SCRIPTS/build_prog.py @@ -2,51 +2,57 @@ import os import sys -import shutil import yaml import subprocess + def build_prog(*prog_list): - r = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 - - if len(prog_list) == 0: - ## build all programs - cwd = os.path.abspath(os.path.curdir) - os.chdir(os.path.join(testsuite_dir, "PROGRAMS")) - p = subprocess.Popen(["make"]) - p.wait() - if p.returncode != 0: - print ("ERROR: Failed in building all programs\n", file=sys.stderr) - r = p.returncode - os.chdir(cwd) - else: - cwd = os.path.abspath(os.path.curdir) - for prog in prog_list: - ## build prog - if prog not in list(suite["PROGRAMS"].keys()): - print("WARNING: program:", prog, "not defined in test_suite.yaml\n", file=sys.stderr) - continue - pd = os.path.join(testsuite_dir, "PROGRAMS", prog) - os.chdir(pd) - p = subprocess.Popen(["make"]) - p.wait() - if p.returncode != 0: - print ("ERROR: Failed in building program:", prog, file=sys.stderr) - r = p.returncode - - os.chdir(cwd) - - return r + r = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + if len(prog_list) == 0: + ## build all programs + cwd = os.path.abspath(os.path.curdir) + os.chdir(os.path.join(testsuite_dir, "PROGRAMS")) + p = subprocess.Popen(["make"]) + p.wait() + if p.returncode != 0: + print("ERROR: Failed in building all programs\n", file=sys.stderr) + r = p.returncode + os.chdir(cwd) + else: + cwd = os.path.abspath(os.path.curdir) + for prog in prog_list: + ## build prog + if prog not in list(suite["PROGRAMS"].keys()): + print( + "WARNING: program:", + prog, + "not defined in test_suite.yaml\n", + file=sys.stderr, + ) + continue + pd = os.path.join(testsuite_dir, "PROGRAMS", prog) + os.chdir(pd) + p = subprocess.Popen(["make"]) + p.wait() + if p.returncode != 0: + print("ERROR: Failed in building program:", prog, file=sys.stderr) + r = p.returncode + + os.chdir(cwd) + + return r + if __name__ == "__main__": - r = build_prog(*sys.argv[1:]) - sys.exit(r) \ No newline at end of file + r = build_prog(*sys.argv[1:]) + sys.exit(r) diff --git a/test_suite/SCRIPTS/check_injection.py b/test_suite/SCRIPTS/check_injection.py index 07373b18..240335f5 100644 --- a/test_suite/SCRIPTS/check_injection.py +++ b/test_suite/SCRIPTS/check_injection.py @@ -2,113 +2,114 @@ import os import sys -import shutil import yaml -import subprocess -def examineTraceFile(work_dir): - try: - inputyaml = open(os.path.join(work_dir, 'input.yaml'), 'r') - except: - print ("FAIL: (ERROR) input.yaml not found! work_dir:", work_dir) - return False - config_dict = yaml.safe_load(inputyaml) - try: - if config_dict['compileOption']['tracingPropagation'] == True: - ## we should have trace file - tracefile = os.path.join(work_dir, 'llfi', 'baseline', 'llfi.stat.trace.prof.txt') - if os.path.isfile(tracefile) and os.path.getsize(tracefile): - return True - else: - return False - else: - ## Tracing option disabled, pass - return True - except: - ## Tracing option disabled, pass - return True +def examineTraceFile(work_dir): + try: + with open(os.path.join(work_dir, "input.yaml"), "r") as inputyaml: + config_dict = yaml.safe_load(inputyaml) + except OSError: + print("FAIL: (ERROR) input.yaml not found! work_dir:", work_dir) + return False + try: + if config_dict["compileOption"]["tracingPropagation"]: + ## we should have trace file + tracefile = os.path.join( + work_dir, "llfi", "baseline", "llfi.stat.trace.prof.txt" + ) + if os.path.isfile(tracefile) and os.path.getsize(tracefile): + return True + else: + return False + else: + ## Tracing option disabled, pass + return True + except Exception: + ## Tracing option disabled, pass + return True def checkLLFIDir(work_dir, target_IR, prog_input): - llfi_dir = os.path.join(work_dir, "llfi") - if os.path.isdir(llfi_dir) == False: - return "FAIL: No ./llfi folder found!" - stats_dir = os.path.join(llfi_dir, "llfi_stat_output") - if os.path.isdir(stats_dir) == False: - return "FAIL: No ./llfi/llfi_stat_output folder found!" - baseline_dir = os.path.join(llfi_dir, "baseline") - if os.path.isdir(baseline_dir) == False: - return "FAIL: No ./llfi/baseline folder found!" - prog_output_dir = os.path.join(llfi_dir, "prog_output") - if os.path.isdir(prog_output_dir) == False: - return "FAIL: No ./llfi/prog_output folder found!" - std_output_dir = os.path.join(llfi_dir, "std_output") - if os.path.isdir(std_output_dir) == False: - return "FAIL: No ./llfi/std_output folder found!" + llfi_dir = os.path.join(work_dir, "llfi") + if not os.path.isdir(llfi_dir): + return "FAIL: No ./llfi folder found!" + stats_dir = os.path.join(llfi_dir, "llfi_stat_output") + if not os.path.isdir(stats_dir): + return "FAIL: No ./llfi/llfi_stat_output folder found!" + baseline_dir = os.path.join(llfi_dir, "baseline") + if not os.path.isdir(baseline_dir): + return "FAIL: No ./llfi/baseline folder found!" + prog_output_dir = os.path.join(llfi_dir, "prog_output") + if not os.path.isdir(prog_output_dir): + return "FAIL: No ./llfi/prog_output folder found!" + std_output_dir = os.path.join(llfi_dir, "std_output") + if not os.path.isdir(std_output_dir): + return "FAIL: No ./llfi/std_output folder found!" - stats = [f for f in os.listdir(stats_dir)] - if len(stats) == 0: - return "FAIL: No stats file found!" + stats = [f for f in os.listdir(stats_dir)] + if len(stats) == 0: + return "FAIL: No stats file found!" - if examineTraceFile(work_dir) == False: - return "FAIL: Tracing was enabled byt trace file not generated!" + if not examineTraceFile(work_dir): + return "FAIL: Tracing was enabled byt trace file not generated!" - return "PASS" + return "PASS" def check_injection(*prog_list): - r = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 + r = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + work_dict = {} + for test in suite.get("HardwareFaults", {}): + if len(prog_list) == 0 or test in prog_list or "HardwareFaults" in prog_list: + work_dict["./HardwareFaults/" + test] = suite["HardwareFaults"][test] + for test in suite.get("BatchMode", {}): + if len(prog_list) == 0 or test in prog_list or "BatchMode" in prog_list: + work_dict["./BatchMode/" + test] = suite["BatchMode"][test] + + result_list = [] + for test_path in work_dict: + inject_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) + inject_prog = suite["PROGRAMS"][work_dict[test_path]][0] + inject_input = str(suite["INPUTS"][work_dict[test_path]]) + if test_path.startswith("./BatchMode"): + # print("\tChecking on BatchMode:", test_path) + models = [ + m + for m in os.listdir(inject_dir) + if os.path.isdir(os.path.join(inject_dir, m)) + ] + for m in models: + subdir = os.path.join(inject_dir, m) + # print("\t\tChecking on model:", m) + result = checkLLFIDir(subdir, inject_prog, inject_input) + if result != "PASS": + break + if len(models) == 0: + result = "Subdirectories for failure modes not found!" + else: + result = checkLLFIDir(inject_dir, inject_prog, inject_input) + if result != "PASS": + r += 1 + record = {"name": test_path, "result": result} + result_list.append(record) - work_dict = {} - for test in suite.get("SoftwareFaults", {}): - if len(prog_list) == 0 or test in prog_list or "SoftwareFaults" in prog_list: - work_dict["./SoftwareFaults/"+test] = suite["SoftwareFaults"][test] - for test in suite.get("HardwareFaults", {}): - if len(prog_list) == 0 or test in prog_list or "HardwareFaults" in prog_list: - work_dict["./HardwareFaults/"+test] = suite["HardwareFaults"][test] - for test in suite.get("BatchMode", {}): - if len(prog_list) == 0 or test in prog_list or "BatchMode" in prog_list: - work_dict["./BatchMode/"+test] = suite["BatchMode"][test] - - - result_list = [] - for test_path in work_dict: - inject_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) - inject_prog = suite["PROGRAMS"][work_dict[test_path]][0] - inject_input = str(suite["INPUTS"][work_dict[test_path]]) - if test_path.startswith('./BatchMode'): - # print("\tChecking on BatchMode:", test_path) - models = [m for m in os.listdir(inject_dir) if os.path.isdir(os.path.join(inject_dir, m))] - for m in models: - subdir = os.path.join(inject_dir, m) - # print("\t\tChecking on model:", m) - result = checkLLFIDir(subdir, inject_prog, inject_input) - if result != "PASS": - break - if len(models) == 0: - result = "Subdirectories for failure modes not found!" - else: - result = checkLLFIDir(inject_dir, inject_prog, inject_input) - if result != "PASS": - r += 1 - record = {"name":test_path, "result":result} - result_list.append(record) + return r, result_list - return r, result_list if __name__ == "__main__": - r, result_list = check_injection(*sys.argv[1:]) - print ("=============== Result ===============") - for record in result_list: - print(record["name"], "\t\t", record["result"]) - sys.exit(r) \ No newline at end of file + r, result_list = check_injection(*sys.argv[1:]) + print("=============== Result ===============") + for record in result_list: + print(record["name"], "\t\t", record["result"]) + sys.exit(r) diff --git a/test_suite/SCRIPTS/clean_prog.py b/test_suite/SCRIPTS/clean_prog.py index e211ae6f..a56b1055 100644 --- a/test_suite/SCRIPTS/clean_prog.py +++ b/test_suite/SCRIPTS/clean_prog.py @@ -2,51 +2,57 @@ import os import sys -import shutil import yaml import subprocess + def clean_prog(*prog_list): - r = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 - - if len(prog_list) == 0: - ## clean all programs - cwd = os.path.abspath(os.path.curdir) - os.chdir(os.path.join(testsuite_dir, "PROGRAMS")) - p = subprocess.Popen(["make", "clean"]) - p.wait() - if p.returncode != 0: - print ("ERROR: Failed in cleaning all programs\n", file=sys.stderr) - r = p.returncode - os.chdir(cwd) - else: - cwd = os.path.abspath(os.path.curdir) - for prog in prog_list: - ## clean prog - if prog not in list(suite["PROGRAMS"].keys()): - print("WARNING: program:", prog, "not defined in test_suite.yaml\n", file=sys.stderr) - continue - pd = os.path.join(testsuite_dir, "PROGRAMS", prog) - os.chdir(pd) - p = subprocess.Popen(["make", "clean"]) - p.wait() - if p.returncode != 0: - print ("ERROR: Failed in cleaning program:", prog, file=sys.stderr) - r = p.returncode - - os.chdir(cwd) - - return r + r = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + if len(prog_list) == 0: + ## clean all programs + cwd = os.path.abspath(os.path.curdir) + os.chdir(os.path.join(testsuite_dir, "PROGRAMS")) + p = subprocess.Popen(["make", "clean"]) + p.wait() + if p.returncode != 0: + print("ERROR: Failed in cleaning all programs\n", file=sys.stderr) + r = p.returncode + os.chdir(cwd) + else: + cwd = os.path.abspath(os.path.curdir) + for prog in prog_list: + ## clean prog + if prog not in list(suite["PROGRAMS"].keys()): + print( + "WARNING: program:", + prog, + "not defined in test_suite.yaml\n", + file=sys.stderr, + ) + continue + pd = os.path.join(testsuite_dir, "PROGRAMS", prog) + os.chdir(pd) + p = subprocess.Popen(["make", "clean"]) + p.wait() + if p.returncode != 0: + print("ERROR: Failed in cleaning program:", prog, file=sys.stderr) + r = p.returncode + + os.chdir(cwd) + + return r + if __name__ == "__main__": - r = clean_prog(*sys.argv[1:]) - sys.exit(r) \ No newline at end of file + r = clean_prog(*sys.argv[1:]) + sys.exit(r) diff --git a/test_suite/SCRIPTS/clear_all.py b/test_suite/SCRIPTS/clear_all.py index 15364f34..36979b00 100644 --- a/test_suite/SCRIPTS/clear_all.py +++ b/test_suite/SCRIPTS/clear_all.py @@ -5,52 +5,49 @@ import shutil import yaml + def clear_all(): - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 - - ## clear hardware faults - for test in suite["HardwareFaults"]: - fs = [f for f in os.listdir(os.path.join(testsuite_dir, "HardwareFaults", test))] - for f in fs: - if f != "input.yaml": - print("MSG: Removing ", "HardwareFaults/"+test+"/"+f) - if os.path.isdir(os.path.join(testsuite_dir, "HardwareFaults", test, f)): - shutil.rmtree(os.path.join(testsuite_dir, "HardwareFaults", test, f)) - else: - os.remove(os.path.join(testsuite_dir, "HardwareFaults", test, f)) - - ## clear software faults - for test in suite["SoftwareFaults"]: - fs = [f for f in os.listdir(os.path.join(testsuite_dir, "SoftwareFaults", test))] - for f in fs: - if f != "input.yaml": - print("MSG: Removing ", "SoftwareFaults/"+test+"/"+f) - if os.path.isdir(os.path.join(testsuite_dir, "SoftwareFaults", test, f)): - shutil.rmtree(os.path.join(testsuite_dir, "SoftwareFaults", test, f)) - else: - os.remove(os.path.join(testsuite_dir, "SoftwareFaults", test, f)) - - ## clear batch mode faults - for test in suite["BatchMode"]: - fs = [f for f in os.listdir(os.path.join(testsuite_dir, "BatchMode", test))] - for f in fs: - if f != "input.yaml": - print("MSG: Removing ", "BatchMode/"+test+"/"+f) - if os.path.isdir(os.path.join(testsuite_dir, "BatchMode", test, f)): - shutil.rmtree(os.path.join(testsuite_dir, "BatchMode", test, f)) - else: - os.remove(os.path.join(testsuite_dir, "BatchMode", test, f)) - - return 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + ## clear hardware faults + for test in suite["HardwareFaults"]: + fs = [ + f for f in os.listdir(os.path.join(testsuite_dir, "HardwareFaults", test)) + ] + for f in fs: + if f != "input.yaml": + print("MSG: Removing ", "HardwareFaults/" + test + "/" + f) + if os.path.isdir( + os.path.join(testsuite_dir, "HardwareFaults", test, f) + ): + shutil.rmtree( + os.path.join(testsuite_dir, "HardwareFaults", test, f) + ) + else: + os.remove(os.path.join(testsuite_dir, "HardwareFaults", test, f)) + + ## clear batch mode faults + for test in suite["BatchMode"]: + fs = [f for f in os.listdir(os.path.join(testsuite_dir, "BatchMode", test))] + for f in fs: + if f != "input.yaml": + print("MSG: Removing ", "BatchMode/" + test + "/" + f) + if os.path.isdir(os.path.join(testsuite_dir, "BatchMode", test, f)): + shutil.rmtree(os.path.join(testsuite_dir, "BatchMode", test, f)) + else: + os.remove(os.path.join(testsuite_dir, "BatchMode", test, f)) + + return 0 + if __name__ == "__main__": - r = clear_all() - sys.exit(r) \ No newline at end of file + r = clear_all() + sys.exit(r) diff --git a/test_suite/SCRIPTS/clear_llfi.py b/test_suite/SCRIPTS/clear_llfi.py index 9edf58b7..d329ddcd 100644 --- a/test_suite/SCRIPTS/clear_llfi.py +++ b/test_suite/SCRIPTS/clear_llfi.py @@ -5,52 +5,49 @@ import shutil import yaml + def clear_llfi(): - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 - - ## clear hardware faults - for test in suite["HardwareFaults"]: - fs = [f for f in os.listdir(os.path.join(testsuite_dir, "HardwareFaults", test))] - for f in fs: - if f.startswith("llfi"): - print("MSG: Removing ", "HardwareFaults/"+test+"/"+f) - if os.path.isdir(os.path.join(testsuite_dir, "HardwareFaults", test, f)): - shutil.rmtree(os.path.join(testsuite_dir, "HardwareFaults", test, f)) - else: - os.remove(os.path.join(testsuite_dir, "HardwareFaults", test, f)) - - ## clear software faults - for test in suite["SoftwareFaults"]: - fs = [f for f in os.listdir(os.path.join(testsuite_dir, "SoftwareFaults", test))] - for f in fs: - if f.startswith("llfi"): - print("MSG: Removing ", "SoftwareFaults/"+test+"/"+f) - if os.path.isdir(os.path.join(testsuite_dir, "SoftwareFaults", test, f)): - shutil.rmtree(os.path.join(testsuite_dir, "SoftwareFaults", test, f)) - else: - os.remove(os.path.join(testsuite_dir, "SoftwareFaults", test, f)) - - ## clear software faults - for test in suite["BatchMode"]: - fs = [f for f in os.listdir(os.path.join(testsuite_dir, "BatchMode", test))] - for f in fs: - if f.startswith("llfi"): - print("MSG: Removing ", "BatchMode/"+test+"/"+f) - if os.path.isdir(os.path.join(testsuite_dir, "BatchMode", test, f)): - shutil.rmtree(os.path.join(testsuite_dir, "BatchMode", test, f)) - else: - os.remove(os.path.join(testsuite_dir, "BatchMode", test, f)) - - return 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + ## clear hardware faults + for test in suite["HardwareFaults"]: + fs = [ + f for f in os.listdir(os.path.join(testsuite_dir, "HardwareFaults", test)) + ] + for f in fs: + if f.startswith("llfi"): + print("MSG: Removing ", "HardwareFaults/" + test + "/" + f) + if os.path.isdir( + os.path.join(testsuite_dir, "HardwareFaults", test, f) + ): + shutil.rmtree( + os.path.join(testsuite_dir, "HardwareFaults", test, f) + ) + else: + os.remove(os.path.join(testsuite_dir, "HardwareFaults", test, f)) + + ## clear batch mode faults + for test in suite["BatchMode"]: + fs = [f for f in os.listdir(os.path.join(testsuite_dir, "BatchMode", test))] + for f in fs: + if f.startswith("llfi"): + print("MSG: Removing ", "BatchMode/" + test + "/" + f) + if os.path.isdir(os.path.join(testsuite_dir, "BatchMode", test, f)): + shutil.rmtree(os.path.join(testsuite_dir, "BatchMode", test, f)) + else: + os.remove(os.path.join(testsuite_dir, "BatchMode", test, f)) + + return 0 + if __name__ == "__main__": - r = clear_llfi() - sys.exit(r) \ No newline at end of file + r = clear_llfi() + sys.exit(r) diff --git a/test_suite/SCRIPTS/deploy_prog.py b/test_suite/SCRIPTS/deploy_prog.py index 0ba41f4a..76a9be6e 100644 --- a/test_suite/SCRIPTS/deploy_prog.py +++ b/test_suite/SCRIPTS/deploy_prog.py @@ -4,48 +4,52 @@ import sys import shutil import yaml -import subprocess + def deploy_prog(*prog_list): - r = 0 - copied = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 + r = 0 + copied = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + work_dict = {} + for test in suite.get("HardwareFaults", {}): + if len(prog_list) == 0 or test in prog_list or "HardwareFaults" in prog_list: + work_dict["./HardwareFaults/" + test] = suite["HardwareFaults"][test] + for test in suite.get("BatchMode", {}): + if len(prog_list) == 0 or test in prog_list or "BatchMode" in prog_list: + work_dict["./BatchMode/" + test] = suite["BatchMode"][test] + + for test_path in work_dict: + src_dir = os.path.join(testsuite_dir, "PROGRAMS", work_dict[test_path]) + req_files = [f for f in suite["PROGRAMS"][work_dict[test_path]]] + dst_dir = os.path.join(testsuite_dir, test_path) + for f in req_files: + src_path = os.path.join(src_dir, f) + try: + shutil.copy(src_path, dst_dir) + except Exception: + print( + "ERROR: Failed in copying program files:", + work_dict[test_path], + "for test:", + test_path, + file=sys.stderr, + ) + r += 1 + else: + copied += 1 + print("MSG:", copied, "files copied\n") + return r - work_dict = {} - for test in suite.get("SoftwareFaults", {}): - if len(prog_list) == 0 or test in prog_list or "SoftwareFaults" in prog_list: - work_dict["./SoftwareFaults/"+test] = suite["SoftwareFaults"][test] - for test in suite.get("HardwareFaults", {}): - if len(prog_list) == 0 or test in prog_list or "HardwareFaults" in prog_list: - work_dict["./HardwareFaults/"+test] = suite["HardwareFaults"][test] - for test in suite.get("BatchMode", {}): - if len(prog_list) == 0 or test in prog_list or "BatchMode" in prog_list: - work_dict["./BatchMode/"+test] = suite["BatchMode"][test] - - for test_path in work_dict: - src_dir = os.path.join(testsuite_dir, "PROGRAMS", work_dict[test_path]) - req_files = [f for f in suite["PROGRAMS"][work_dict[test_path]]] - dst_dir = os.path.join(testsuite_dir, test_path) - for f in req_files: - src_path = os.path.join(src_dir, f) - try: - shutil.copy(src_path, dst_dir) - except: - print ("ERROR: Failed in copying program files:", work_dict[test_path], "for test:", test_path, file=sys.stderr) - r += 1 - else: - copied += 1 - print ("MSG:", copied, "files copied\n") - return r if __name__ == "__main__": - r = deploy_prog(*sys.argv[1:]) - sys.exit(r) \ No newline at end of file + r = deploy_prog(*sys.argv[1:]) + sys.exit(r) diff --git a/test_suite/SCRIPTS/inject_prog.py b/test_suite/SCRIPTS/inject_prog.py index ddd18176..17547b1e 100644 --- a/test_suite/SCRIPTS/inject_prog.py +++ b/test_suite/SCRIPTS/inject_prog.py @@ -2,17 +2,16 @@ import os import sys -import shutil import yaml import subprocess import time -from threading import Thread +from threading import Thread try: - from Queue import Queue, Empty + from Queue import Queue, Empty except ImportError: - from queue import Queue, Empty # python 3.x -ON_POSIX = 'posix' in sys.builtin_module_names + from queue import Queue, Empty # python 3.x +ON_POSIX = "posix" in sys.builtin_module_names instrument_script = "" profile_script = "" @@ -20,251 +19,261 @@ batchinstrument_script = "" batchprofile_script = "" batchinjectfault_script = "" -autoscan_script = "" + def enqueue_output(out, queue): - for line in iter(out.readline, b''): - queue.put(line) - out.close() + for line in iter(out.readline, b""): + queue.put(line) + out.close() + def startEchoServer(work_dir): - print("using startEchoServer") - execlist = ["stdbuf", '-i0', '-o0', '-e0'] - execlist.extend([os.path.join(work_dir, "echoServer.exe")]) - server = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - q = Queue() - t = Thread(target=enqueue_output, args=(server.stdout, q)) - t.daemon = True # thread dies with the program - t.start() - count = 0 - while server.poll() == None: - if count > 50: - server.terminate() - return startEchoServer(work_dir) - try: line = q.get_nowait() # or q.get(timeout=.1) - except Empty: - print('no output yet') - count += 1 - time.sleep(1) - else: - print (line) - line = str(line) - if "Server running...waiting for connections." in line: - return server - else: - count += 1 - time.sleep(1) + print("using startEchoServer") + execlist = ["stdbuf", "-i0", "-o0", "-e0"] + execlist.extend([os.path.join(work_dir, "echoServer.exe")]) + server = subprocess.Popen( + execlist, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + q = Queue() + t = Thread(target=enqueue_output, args=(server.stdout, q)) + t.daemon = True # thread dies with the program + t.start() + count = 0 + while server.poll() is None: + if count > 50: + server.terminate() + return startEchoServer(work_dir) + try: + line = q.get_nowait() # or q.get(timeout=.1) + except Empty: + print("no output yet") + count += 1 + time.sleep(1) + else: + print(line) + line = str(line) + if "Server running...waiting for connections." in line: + return server + else: + count += 1 + time.sleep(1) + def callLLFI(work_dir, target_IR, prog_input): - global instrument_script - global profile_script - global injectfault_script + try: + os.chdir(work_dir) + except Exception: + print("ERROR: Unable to change directory to:", work_dir) + return -1, None + with open("llfi.test.log.instrument.txt", "w", buffering=1) as log: + p = subprocess.Popen( + [instrument_script, "--readable", "-lpthread", target_IR], + stdout=log, + stderr=log, + ) + p.wait() + if p.returncode != 0: + print("ERROR: instrument failed for:", work_dir, target_IR) + return -1, None + else: + print("MSG: instrument succeed for:", work_dir, target_IR) + + with open("llfi.test.log.profile.txt", "w", buffering=1) as log: + if target_IR == "echoClient.ll": + server = startEchoServer(work_dir) + print( + "MSG: echoServer.ll started for profile, please make sure there is only one echoServer running\n" + ) + time.sleep(2) + profile_exe = target_IR.split(".ll")[0] + "-profiling.exe" + execlist = [profile_script, "./llfi/" + profile_exe] + execlist.extend(prog_input.split(" ")) + p = subprocess.Popen(execlist, stdout=log, stderr=log) + p.wait() + if target_IR == "echoClient.ll": + try: + server.terminate() + print("MSG: echoServer.exe terminated for profile.\n") + except Exception: + print( + "ERROR: Unable to terminate echoServer.exe in profile for:", + work_dir, + ) + if p.returncode != 0: + print("ERROR: profile failed for:", work_dir, target_IR) + return -1, None + else: + print("MSG: profile succeed for:", work_dir, target_IR, prog_input) - try: - os.chdir(work_dir) - except: - print ("ERROR: Unable to change directory to:", work_dir) - return -1, None - with open("llfi.test.log.instrument.txt", 'w', buffering=1) as log: - p = subprocess.Popen([instrument_script, "--readable", "-lpthread", target_IR], stdout=log, stderr=log) - p.wait() - if p.returncode != 0: - print ("ERROR: instrument failed for:", work_dir, target_IR) - return -1, None - else: - print ("MSG: instrument succeed for:", work_dir, target_IR) + with open("llfi.test.log.injectFault.txt", "w", buffering=1) as log: + if target_IR == "echoClient.ll": + server = startEchoServer(work_dir) + print( + "MSG: echoServer.ll started for injectfault, please make sure there is only one echoServer running\n" + ) + time.sleep(2) + faultinjection_exe = target_IR.split(".ll")[0] + "-faultinjection.exe" + execlist = [injectfault_script, "./llfi/" + faultinjection_exe] + execlist.extend(prog_input.split(" ")) + p = subprocess.Popen(execlist, stdout=log, stderr=log) + t = {"name": " ".join(work_dir.split("/")[-3:]) + "/" + target_IR, "process": p} + if target_IR == "echoClient.ll": + p.wait() + try: + server.terminate() + print("MSG: echoServer.exe terminated for profile.\n") + except Exception: + print( + "ERROR: Unable to terminate echoServer.exe in injectfault for", + work_dir, + ) - with open("llfi.test.log.profile.txt", 'w', buffering=1) as log: - if target_IR == "echoClient.ll": - server = startEchoServer(work_dir) - print ("MSG: echoServer.ll started for profile, please make sure there is only one echoServer running\n") - time.sleep(2) - profile_exe = target_IR.split(".ll")[0]+"-profiling.exe" - execlist = [profile_script, "./llfi/"+profile_exe] - execlist.extend(prog_input.split(' ')) - p = subprocess.Popen(execlist, stdout=log, stderr=log) - p.wait() - if target_IR == "echoClient.ll": - try: - server.terminate() - print ("MSG: echoServer.exe terminated for profile.\n") - except: - print ("ERROR: Unable to terminate echoServer.exe in profile for:", work_dir) - if p.returncode != 0: - print ("ERROR: profile failed for:", work_dir, target_IR) - return -1, None - else: - print ("MSG: profile succeed for:", work_dir, target_IR, prog_input) + return 0, t - with open("llfi.test.log.injectFault.txt", 'w', buffering=1) as log: - if target_IR == "echoClient.ll": - server = startEchoServer(work_dir) - print ("MSG: echoServer.ll started for injectfault, please make sure there is only one echoServer running\n") - time.sleep(2) - faultinjection_exe = target_IR.split(".ll")[0]+"-faultinjection.exe" - execlist = [injectfault_script, "./llfi/"+faultinjection_exe] - execlist.extend(prog_input.split(' ')) - p = subprocess.Popen(execlist, stdout=log, stderr=log) - t = {"name":' '.join(work_dir.split('/')[-3:])+"/"+target_IR, - "process":p} - if target_IR == "echoClient.ll": - p.wait() - try: - server.terminate() - print ("MSG: echoServer.exe terminated for profile.\n") - except: - print ("ERROR: Unable to terminate echoServer.exe in injectfault for", work_dir) - - return 0, t def callBatchLLFI(work_dir, target_IR, prog_input): - global batchinstrument_script - global batchprofile_script - global batchinjectfault_script - global autoscan_script + try: + os.chdir(work_dir) + except Exception: + print("ERROR: Unable to change directory to:", work_dir) + return -1, None - try: - os.chdir(work_dir) - except: - print ("ERROR: Unable to change directory to:", work_dir) - return -1, None + with open("llfi.test.log.instrument.txt", "w", buffering=1) as log: + p = subprocess.Popen( + [batchinstrument_script, "--readable", "-lpthread", target_IR], + stdout=log, + stderr=log, + ) + p.wait() + if p.returncode != 0: + print("ERROR: batchInstrument failed for:", work_dir, target_IR) + return -1, None + else: + print("MSG: batchInstrument succeed for:", work_dir, target_IR) - if 'SoftwareFailureAutoScan' in os.path.basename(work_dir): - with open("llfi.test.log.SoftwareFailureAutoScan.txt", 'w', buffering=1) as log: - p = subprocess.Popen([autoscan_script, target_IR], stdout=log, stderr=log) - p.wait() - if p.returncode != 0: - print ("ERROR: SoftwareFailureAutoScan failed for:", work_dir, target_IR) - return -1, None - else: - print ("MSG: SoftwareFailureAutoScan succeed for:", work_dir, target_IR) + with open("llfi.test.log.profile.txt", "w", buffering=1) as log: + if target_IR == "echoClient.ll": + server = startEchoServer(work_dir) + print( + "MSG: echoServer.ll started for profile, please make sure there is only one echoServer running\n" + ) + time.sleep(2) + execlist = [batchprofile_script, target_IR] + execlist.extend(prog_input.split(" ")) + p = subprocess.Popen(execlist, stdout=log, stderr=log) + p.wait() + if target_IR == "echoClient.ll": + try: + server.terminate() + print("MSG: echoServer.exe terminated for profile.\n") + except Exception: + print( + "ERROR: Unable to terminate echoServer.exe in profile for:", + work_dir, + ) + if p.returncode != 0: + print("ERROR: profile failed for:", work_dir, target_IR) + return -1, None + else: + print("MSG: profile succeed for:", work_dir, target_IR, prog_input) - with open("llfi.test.log.instrument.txt", 'w', buffering=1) as log: - p = subprocess.Popen([batchinstrument_script, "--readable", "-lpthread", target_IR], stdout=log, stderr=log) - p.wait() - if p.returncode != 0: - print ("ERROR: batchInstrument failed for:", work_dir, target_IR) - return -1, None - else: - print ("MSG: batchInstrument succeed for:", work_dir, target_IR) + with open("llfi.test.log.injectFault.txt", "w", buffering=1) as log: + if target_IR == "echoClient.ll": + server = startEchoServer(work_dir) + print( + "MSG: echoServer.ll started for injectfault, please make sure there is only one echoServer running\n" + ) + time.sleep(2) + execlist = [batchinjectfault_script, target_IR] + execlist.extend(prog_input.split(" ")) + p = subprocess.Popen(execlist, stdout=log, stderr=log) + t = {"name": " ".join(work_dir.split("/")[-3:]) + "/" + target_IR, "process": p} + if target_IR == "echoClient.ll": + p.wait() + try: + server.terminate() + print("MSG: echoServer.exe terminated for profile.\n") + except Exception: + print( + "ERROR: Unable to terminate echoServer.exe in injectfault for", + work_dir, + ) - with open("llfi.test.log.profile.txt", 'w', buffering=1) as log: - if target_IR == "echoClient.ll": - server = startEchoServer(work_dir) - print ("MSG: echoServer.ll started for profile, please make sure there is only one echoServer running\n") - time.sleep(2) - execlist = [batchprofile_script, target_IR] - execlist.extend(prog_input.split(' ')) - p = subprocess.Popen(execlist, stdout=log, stderr=log) - p.wait() - if target_IR == "echoClient.ll": - try: - server.terminate() - print ("MSG: echoServer.exe terminated for profile.\n") - except: - print ("ERROR: Unable to terminate echoServer.exe in profile for:", work_dir) - if p.returncode != 0: - print ("ERROR: profile failed for:", work_dir, target_IR) - return -1, None - else: - print ("MSG: profile succeed for:", work_dir, target_IR, prog_input) + return 0, t - with open("llfi.test.log.injectFault.txt", 'w', buffering=1) as log: - if target_IR == "echoClient.ll": - server = startEchoServer(work_dir) - print ("MSG: echoServer.ll started for injectfault, please make sure there is only one echoServer running\n") - time.sleep(2) - execlist = [batchinjectfault_script, target_IR] - execlist.extend(prog_input.split(' ')) - p = subprocess.Popen(execlist, stdout=log, stderr=log) - t = {"name":' '.join(work_dir.split('/')[-3:])+"/"+target_IR, - "process":p} - if target_IR == "echoClient.ll": - p.wait() - try: - server.terminate() - print ("MSG: echoServer.exe terminated for profile.\n") - except: - print ("ERROR: Unable to terminate echoServer.exe in injectfault for", work_dir) - - return 0, t def inject_prog(num_threads, *prog_list): - global instrument_script - global profile_script - global injectfault_script - global batchinstrument_script - global batchprofile_script - global batchinjectfault_script - global autoscan_script + global instrument_script + global profile_script + global injectfault_script + global batchinstrument_script + global batchprofile_script + global batchinjectfault_script + + r = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + llfi_bin_dir = os.path.join(script_dir, "../../bin") + instrument_script = os.path.join(llfi_bin_dir, "instrument") + profile_script = os.path.join(llfi_bin_dir, "profile") + injectfault_script = os.path.join(llfi_bin_dir, "injectfault") + batchinstrument_script = os.path.join(llfi_bin_dir, "batchInstrument") + batchprofile_script = os.path.join(llfi_bin_dir, "batchProfile") + batchinjectfault_script = os.path.join(llfi_bin_dir, "batchInjectfault") + + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + work_dict = {} + for test in suite.get("HardwareFaults", {}): + if len(prog_list) == 0 or test in prog_list or "HardwareFaults" in prog_list: + work_dict["./HardwareFaults/" + test] = suite["HardwareFaults"][test] + for test in suite.get("BatchMode", {}): + if len(prog_list) == 0 or test in prog_list or "BatchMode" in prog_list: + work_dict["./BatchMode/" + test] = suite["BatchMode"][test] - r = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - llfi_bin_dir = os.path.join(script_dir, '../../bin') - instrument_script = os.path.join(llfi_bin_dir, "instrument") - profile_script = os.path.join(llfi_bin_dir, "profile") - injectfault_script = os.path.join(llfi_bin_dir, "injectfault") - batchinstrument_script = os.path.join(llfi_bin_dir, "batchInstrument") - batchprofile_script = os.path.join(llfi_bin_dir, "batchProfile") - batchinjectfault_script = os.path.join(llfi_bin_dir, "batchInjectfault") - autoscan_script = os.path.join(llfi_bin_dir, "SoftwareFailureAutoScan") - - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 + running_list = [] + exitcode_list = [] + for test_path in work_dict: + while len(running_list) >= num_threads: + for t in running_list: + if t["process"].poll() is None: + continue + else: + print("MSG: Injection for:", t["name"], "finished!\n") + running_list.remove(t) + record = {"name": t["name"], "exitcode": t["process"].returncode} + exitcode_list.append(record) - work_dict = {} - for test in suite.get("SoftwareFaults", {}): - if len(prog_list) == 0 or test in prog_list or "SoftwareFaults" in prog_list: - work_dict["./SoftwareFaults/"+test] = suite["SoftwareFaults"][test] - for test in suite.get("HardwareFaults", {}): - if len(prog_list) == 0 or test in prog_list or "HardwareFaults" in prog_list: - work_dict["./HardwareFaults/"+test] = suite["HardwareFaults"][test] - for test in suite.get("BatchMode", {}): - if len(prog_list) == 0 or test in prog_list or "BatchMode" in prog_list: - work_dict["./BatchMode/"+test] = suite["BatchMode"][test] - - running_list = [] - exitcode_list = [] - for test_path in work_dict: - while(len(running_list) >= num_threads): - for t in running_list: - if t["process"].poll() is None: - continue - else: - print ("MSG: Injection for:", t["name"], "finished!\n") - running_list.remove(t) - record={"name":t["name"], "exitcode":t["process"].returncode} - exitcode_list.append(record) + inject_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) + inject_prog = suite["PROGRAMS"][work_dict[test_path]][0] + inject_input = str(suite["INPUTS"][work_dict[test_path]]) + if test_path.startswith("./BatchMode"): + code, t = callBatchLLFI(inject_dir, inject_prog, inject_input) + else: + code, t = callLLFI(inject_dir, inject_prog, inject_input) + if code != 0: + print("ERROR: Skip:", test_path) + continue + running_list.append(t) - inject_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) - inject_prog = suite["PROGRAMS"][work_dict[test_path]][0] - inject_input = str(suite["INPUTS"][work_dict[test_path]]) - if test_path.startswith('./BatchMode'): - code, t = callBatchLLFI(inject_dir, inject_prog, inject_input) - else: - code, t = callLLFI(inject_dir, inject_prog, inject_input) - if code != 0: - print ("ERROR: Skip:", test_path) - continue - running_list.append(t) + while len(running_list) > 0: + for t in running_list: + if t["process"].poll() is None: + continue + else: + print("MSG: Injection for:", t["name"], "finished!\n") + running_list.remove(t) + record = {"name": t["name"], "exitcode": t["process"].returncode} + exitcode_list.append(record) + return r - while(len(running_list) > 0): - for t in running_list: - if t["process"].poll() is None: - continue - else: - print ("MSG: Injection for:", t["name"], "finished!\n") - running_list.remove(t) - record={"name":t["name"], "exitcode":t["process"].returncode} - exitcode_list.append(record) - return r if __name__ == "__main__": - r = inject_prog(int(sys.argv[1]), *sys.argv[2:]) - sys.exit(r) + r = inject_prog(int(sys.argv[1]), *sys.argv[2:]) + sys.exit(r) diff --git a/test_suite/SCRIPTS/llfi_test.py b/test_suite/SCRIPTS/llfi_test.py index 926654c3..6bba0c1e 100755 --- a/test_suite/SCRIPTS/llfi_test.py +++ b/test_suite/SCRIPTS/llfi_test.py @@ -2,20 +2,20 @@ """ -%(prog)s is a test suite driver script to run all the steps of LLFI regression test. +%(prog)s is a test suite driver script to run all the steps of LLFI regression test. Usage: %(prog)s [OPTIONS] List of options: --threads : number of threads to be used for fault injections, default value: 1. ---all: Test all the test cases of LLFI test suite, including fault injection tests, trace analysis tests and make file generation tests. ---all_fault_injections: Test all the test cases of fault injections, including HardwareFaults, SoftwareFaults and BatchMode tests. ---all_software_faults: Test all the test cases of SoftwareFaults. +--all_cpp: Test all the test cases of LLFI test suite, including fault injection tests, trace analysis tests and make file generation tests. +--all_fault_injections: Test all the test cases of fault injections, including HardwareFaults and BatchMode tests. --all_hardware_faults: Test all the test cases of HardwareFaults. --all_batchmode: Test all the test cases of BatchMode fault injections. --all_trace_tools_tests: Test all the tests for trace analysis tools. --all_makefile_generation: Test all the tests for makefile generation script. +--all_ml: Test ML/ONNX tools (CompareLayerOutputs, ExtendONNXModel, outputONNXGraph), TensorFlow/PyTorch ONNX pipelines, ONNX-to-LLVM-IR compilation, and ML fault injection. Tests that require optional dependencies (onnx, pygraphviz, pydot, tensorflow, tf2onnx, torch, onnx-mlir) are reported as SKIP when those packages are absent. --test_cases [test case names]: Test only specified test case. --clean_after_test: Clean all the generate files after testing. @@ -30,250 +30,305 @@ import time options = { - 'all':False, - 'all_fault_injections':False, - 'all_software_faults':False, - 'all_hardware_faults':False, - 'all_batchmode':False, - 'all_trace_tools_tests':False, - 'all_makefile_generation':False, - 'test_cases':[], - 'threads':1, - 'clean_after_test':False, + "all_cpp": False, + "all_fault_injections": False, + "all_hardware_faults": False, + "all_batchmode": False, + "all_trace_tools_tests": False, + "all_makefile_generation": False, + "all_ml": False, + "test_cases": [], + "threads": 1, + "clean_after_test": False, } prog = os.path.basename(sys.argv[0]) verbose = False + def verbosePrint(msg): - global verbose - if verbose: - print(msg) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + if verbose: + print(msg) + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) + def parseArgs(args): - global options - global verbose - argid = 0 - while argid < len(args): - arg = args[argid] - - if arg == "--all": - options['all'] = True - - elif arg == "--all_fault_injections": - options['all_fault_injections'] = True - - elif arg == "--all_software_faults": - options['all_software_faults'] = True - - elif arg == "--all_hardware_faults": - options['all_hardware_faults'] = True - - elif arg == "--all_batchmode": - options['all_batchmode'] = True - - elif arg == "--test_cases": - argid += 1 - while(argid < len(args) and args[argid][0] != '-'): - options['test_cases'].append(str(args[argid])) - argid += 1 - - elif arg == "--threads": - argid += 1 - options['threads'] = int(args[argid]) - - elif arg == "--all_trace_tools_tests": - options['all_trace_tools_tests'] = True - - elif arg == "--all_makefile_generation": - options['all_makefile_generation'] = True - - elif arg == "--clean_after_test": - options['clean_after_test'] = True - - elif arg == "--help" or arg == "-h": - usage() - - elif arg == "--verbose": - verbose = True - - argid += 1 + global verbose + argid = 0 + while argid < len(args): + arg = args[argid] + + if arg == "--all_cpp": + options["all_cpp"] = True + + elif arg == "--all_fault_injections": + options["all_fault_injections"] = True + + elif arg == "--all_hardware_faults": + options["all_hardware_faults"] = True + + elif arg == "--all_batchmode": + options["all_batchmode"] = True + + elif arg == "--test_cases": + argid += 1 + while argid < len(args) and args[argid][0] != "-": + options["test_cases"].append(str(args[argid])) + argid += 1 + + elif arg == "--threads": + argid += 1 + options["threads"] = int(args[argid]) + + elif arg == "--all_trace_tools_tests": + options["all_trace_tools_tests"] = True + + elif arg == "--all_makefile_generation": + options["all_makefile_generation"] = True + + elif arg == "--all_ml": + options["all_ml"] = True + + elif arg == "--clean_after_test": + options["clean_after_test"] = True + + elif arg == "--help" or arg == "-h": + usage() + + elif arg == "--verbose": + verbose = True + + argid += 1 + def startTestRoutine(): - global options - script_dir = os.path.dirname(os.path.realpath(__file__)) - sys.path.append(script_dir) - build_prog_script = os.path.join(script_dir, 'build_prog.py') - deploy_prog_script = os.path.join(script_dir, 'deploy_prog.py') - inject_prog_script = os.path.join(script_dir, 'inject_prog.py') - check_injection_script = os.path.join(script_dir, 'check_injection.py') - test_trace_tools_script = os.path.join(script_dir, 'test_trace_tools.py') - test_generate_makefile_script = os.path.join(script_dir, 'test_generate_makefile_script.py') - clear_all_script = os.path.join(script_dir, 'clear_all.py') - - injection_result_list = [] - trace_result_list = [] - generate_makefile_result_list = [] - - if options['all'] or options['all_batchmode'] or options['all_hardware_faults']\ - or options['all_software_faults'] or options['all_fault_injections']\ - or options['test_cases'] != []: - ## build all the test program - execlist = ['python3', '-u', build_prog_script] - verbosePrint(' '.join(execlist)) - p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - p.wait() - r = p.returncode - if r != 0: - print("ERROR: Failed in building all test programs") - sys.exit(-1) - else: - print("Build test programs successfully.") - - ## deploy programs - execlist = ['python3', '-u', deploy_prog_script] - if options['all_batchmode']: - execlist.append('BatchMode') - elif options['all_software_faults']: - execlist.append('SoftwareFaults') - elif options['all_hardware_faults']: - execlist.append('HardwareFaults') - elif options['test_cases'] != []: - execlist.extend(options['test_cases']) - elif options['all'] or options['all_fault_injections']: - pass - verbosePrint(' '.join(execlist)) - p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - p.wait() - r = p.returncode - if r != 0: - print("ERROR: Failed in deploy test programs") - sys.exit(-1) - else: - print("Deploy test programs successfully.") - - ## start fault injection - execlist = ['python3', '-u', inject_prog_script, str(options['threads'])] - if options['all_batchmode']: - execlist.append('BatchMode') - elif options['all_software_faults']: - execlist.append('SoftwareFaults') - elif options['all_hardware_faults']: - execlist.append('HardwareFaults') - elif options['test_cases'] != []: - execlist.extend(options['test_cases']) - elif options['all'] or options['all_fault_injections']: - pass - verbosePrint(' '.join(execlist)) - p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - p.wait() - r = p.returncode - if r != 0: - print("WARNING: error occurs during fault injection. Continue on testing.") - else: - print("Fault injection ends normally.") - - ## check the injection - import check_injection - prog_list = [] - if options['all_batchmode']: - prog_list.append('BatchMode') - elif options['all_software_faults']: - prog_list.append('SoftwareFaults') - elif options['all_hardware_faults']: - prog_list.append('HardwareFaults') - elif options['test_cases'] != []: - prog_list.extend(options['test_cases']) - elif options['all'] or options['all_fault_injections']: - pass - verbosePrint('Calling: check_injection.check_injection(' + ' '.join(prog_list) + ')') - check_injection_returncode, injection_result_list = check_injection.check_injection(*prog_list) - - ## run trace tools's tests - if options['all_trace_tools_tests'] or options['all'] or options['test_cases'] != []: - import test_trace_tools - prog_list = [] - if options['test_cases'] != []: - prog_list.extend(options['test_cases']) - elif options['all_trace_tools_tests'] or options['all']: - pass - verbosePrint('Calling: test_trace_tools.test_trace_tools(' + ' '.join(prog_list) + ')') - test_trace_tools_returncode, trace_result_list = test_trace_tools.test_trace_tools(*prog_list) - - ## run MakefileGeneration tests - if options['all_makefile_generation'] or options['all'] or options['test_cases'] != []: - import test_generate_makefile - prog_list = [] - if options['test_cases'] != []: - prog_list.extend(options['test_cases']) - elif options['all_makefile_generation'] or options['all']: - pass - verbosePrint('Calling: test_generate_makefile.test_generate_makefile(' + ' '.join(prog_list) + ')') - test_generate_makefile_returncode, generate_makefile_result_list = test_generate_makefile.test_generate_makefile(*prog_list) - - ## collect the results - total = 0 - passed = 0 - if len(injection_result_list) > 0: - print ("==== Check Injection Result ====") - for record in injection_result_list: - print(record["name"], "\t\t", record["result"]) - total += 1 - if record['result'] == 'PASS': - passed += 1 - if len(trace_result_list) > 0: - print ("==== Test Trace Tools Result ====") - for record in trace_result_list: - print(record["name"], "\t\t", record["result"]) - total += 1 - if record['result'] == 'PASS': - passed += 1 - - if len(generate_makefile_result_list) > 0: - print("==== Test MakefileGeneration Tool Result ====") - for record in generate_makefile_result_list: - print(record["name"], '\t\t', record["result"]) - total += 1 - if record['result'] == 'PASS': - passed += 1 - - print("=== Overall Counts ====") - print("Total tests:\t", total) - print("Passed tests:\t", passed) - print("Failed tests:\t", total - passed) - - if options['clean_after_test']: - execlist = ['python3', '-u', clear_all_script] - verbosePrint(' '.join(execlist)) - p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - p.wait() - os.chdir(os.path.join(script_dir, os.pardir, 'PROGRAMS')) - os.system('make clean') - dirs = [d for d in os.listdir(os.path.join(script_dir, os.pardir, 'MakefileGeneration')) - if os.path.isdir(os.path.join(script_dir, os.pardir, 'MakefileGeneration',d))] - print(dirs) - for d in dirs: - p = os.path.join(script_dir, os.pardir, 'MakefileGeneration', d) - os.chdir(p) - os.system('make clean') - - - return 0 + script_dir = os.path.dirname(os.path.realpath(__file__)) + sys.path.append(script_dir) + build_prog_script = os.path.join(script_dir, "build_prog.py") + deploy_prog_script = os.path.join(script_dir, "deploy_prog.py") + inject_prog_script = os.path.join(script_dir, "inject_prog.py") + clear_all_script = os.path.join(script_dir, "clear_all.py") + + injection_result_list = [] + trace_result_list = [] + generate_makefile_result_list = [] + ml_result_list = [] + + if ( + options["all_cpp"] + or options["all_batchmode"] + or options["all_hardware_faults"] + or options["all_fault_injections"] + or options["test_cases"] != [] + ): + ## build all the test program + execlist = ["python3", "-u", build_prog_script] + verbosePrint(" ".join(execlist)) + p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + r = p.returncode + if r != 0: + print("ERROR: Failed in building all test programs") + sys.exit(-1) + else: + print("Build test programs successfully.") + + ## deploy programs + execlist = ["python3", "-u", deploy_prog_script] + if options["all_batchmode"]: + execlist.append("BatchMode") + elif options["all_hardware_faults"]: + execlist.append("HardwareFaults") + elif options["test_cases"] != []: + execlist.extend(options["test_cases"]) + elif options["all_cpp"] or options["all_fault_injections"]: + pass + verbosePrint(" ".join(execlist)) + p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + r = p.returncode + if r != 0: + print("ERROR: Failed in deploy test programs") + sys.exit(-1) + else: + print("Deploy test programs successfully.") + + ## start fault injection + execlist = ["python3", "-u", inject_prog_script, str(options["threads"])] + if options["all_batchmode"]: + execlist.append("BatchMode") + elif options["all_hardware_faults"]: + execlist.append("HardwareFaults") + elif options["test_cases"] != []: + execlist.extend(options["test_cases"]) + elif options["all_cpp"] or options["all_fault_injections"]: + pass + verbosePrint(" ".join(execlist)) + p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + r = p.returncode + if r != 0: + print("WARNING: error occurs during fault injection. Continue on testing.") + else: + print("Fault injection ends normally.") + + ## check the injection + import check_injection + + prog_list = [] + if options["all_batchmode"]: + prog_list.append("BatchMode") + elif options["all_hardware_faults"]: + prog_list.append("HardwareFaults") + elif options["test_cases"] != []: + prog_list.extend(options["test_cases"]) + elif options["all_cpp"] or options["all_fault_injections"]: + pass + verbosePrint( + "Calling: check_injection.check_injection(" + " ".join(prog_list) + ")" + ) + check_injection_returncode, injection_result_list = ( + check_injection.check_injection(*prog_list) + ) + + ## run trace tools's tests + if ( + options["all_trace_tools_tests"] + or options["all_cpp"] + or options["test_cases"] != [] + ): + import test_trace_tools + + prog_list = [] + if options["test_cases"] != []: + prog_list.extend(options["test_cases"]) + elif options["all_trace_tools_tests"] or options["all_cpp"]: + pass + verbosePrint( + "Calling: test_trace_tools.test_trace_tools(" + " ".join(prog_list) + ")" + ) + test_trace_tools_returncode, trace_result_list = ( + test_trace_tools.test_trace_tools(*prog_list) + ) + + ## run ML/ONNX tools tests (not part of --all; requires optional deps) + if options["all_ml"]: + import test_ml_tools + import test_ml_models + import test_instruction_duplication + + verbosePrint("Calling: test_ml_tools.test_ml_tools()") + _, ml_tools_list = test_ml_tools.test_ml_tools() + verbosePrint("Calling: test_ml_models.test_ml_models()") + _, ml_models_list = test_ml_models.test_ml_models() + verbosePrint( + "Calling: test_instruction_duplication.test_instruction_duplication()" + ) + _, sid_list = test_instruction_duplication.test_instruction_duplication() + ml_result_list = ml_tools_list + ml_models_list + sid_list + + ## run MakefileGeneration tests + if ( + options["all_makefile_generation"] + or options["all_cpp"] + or options["test_cases"] != [] + ): + import test_generate_makefile + + prog_list = [] + if options["test_cases"] != []: + prog_list.extend(options["test_cases"]) + elif options["all_makefile_generation"] or options["all_cpp"]: + pass + verbosePrint( + "Calling: test_generate_makefile.test_generate_makefile(" + + " ".join(prog_list) + + ")" + ) + test_generate_makefile_returncode, generate_makefile_result_list = ( + test_generate_makefile.test_generate_makefile(*prog_list) + ) + + ## collect the results + total = 0 + passed = 0 + if len(injection_result_list) > 0: + print("==== Check Injection Result ====") + for record in injection_result_list: + print(record["name"], "\t\t", record["result"]) + total += 1 + if record["result"] == "PASS": + passed += 1 + if len(trace_result_list) > 0: + print("==== Test Trace Tools Result ====") + for record in trace_result_list: + print(record["name"], "\t\t", record["result"]) + total += 1 + if record["result"] == "PASS": + passed += 1 + + if len(generate_makefile_result_list) > 0: + print("==== Test MakefileGeneration Tool Result ====") + for record in generate_makefile_result_list: + print(record["name"], "\t\t", record["result"]) + total += 1 + if record["result"] == "PASS": + passed += 1 + + if len(ml_result_list) > 0: + print("==== Test ML/ONNX Tools Result ====") + for record in ml_result_list: + print(record["name"], "\t\t", record["result"]) + if record["result"].startswith("SKIP"): + continue # SKIPs are neither pass nor fail + total += 1 + if record["result"] == "PASS": + passed += 1 + + print("=== Overall Counts ====") + print("Total tests:\t", total) + print("Passed tests:\t", passed) + print("Failed tests:\t", total - passed) + + if options["clean_after_test"]: + execlist = ["python3", "-u", clear_all_script] + verbosePrint(" ".join(execlist)) + p = subprocess.Popen(execlist, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + os.chdir(os.path.join(script_dir, os.pardir, "PROGRAMS")) + os.system("make clean") + dirs = [ + d + for d in os.listdir( + os.path.join(script_dir, os.pardir, "MakefileGeneration") + ) + if os.path.isdir( + os.path.join(script_dir, os.pardir, "MakefileGeneration", d) + ) + ] + print(dirs) + for d in dirs: + p = os.path.join(script_dir, os.pardir, "MakefileGeneration", d) + os.chdir(p) + os.system("make clean") + + return 0 + if __name__ == "__main__": - if len(sys.argv) == 1: - usage() - parseArgs(sys.argv[1:]) - print("Tests Start on: ", time.ctime()) - r = startTestRoutine() - print("Tests Ends on: ", time.ctime()) - sys.exit(r) + if len(sys.argv) == 1: + usage() + parseArgs(sys.argv[1:]) + print("Tests Start on: ", time.ctime()) + r = startTestRoutine() + print("Tests Ends on: ", time.ctime()) + sys.exit(r) diff --git a/test_suite/SCRIPTS/test_generate_makefile.py b/test_suite/SCRIPTS/test_generate_makefile.py index d2a094cf..b9b63c39 100644 --- a/test_suite/SCRIPTS/test_generate_makefile.py +++ b/test_suite/SCRIPTS/test_generate_makefile.py @@ -2,99 +2,105 @@ import os import sys -import shutil import yaml import subprocess -generate_makefile_script = '' -llvm_interpreter_bin = '' +generate_makefile_script = "" +llvm_interpreter_bin = "" + def callGenerateMakefile(work_dir, resources): - global generate_makefile_script - - execlist = [generate_makefile_script, '--dir', work_dir] - execlist.extend(resources['makefile_generation_args'].split(' ')) - print(' '.join(execlist)) - p = subprocess.Popen(execlist) - p.wait() - if p.returncode != 0: - return ("FAIL: ERROR in calling " + generate_makefile_script + " on " + work_dir) - else: - return ("PASS") + execlist = [generate_makefile_script, "--dir", work_dir] + execlist.extend(resources["makefile_generation_args"].split(" ")) + print(" ".join(execlist)) + p = subprocess.Popen(execlist) + p.wait() + if p.returncode != 0: + return "FAIL: ERROR in calling " + generate_makefile_script + " on " + work_dir + else: + return "PASS" + def callLLVMInterpreter(work_dir, resources): - global llvm_interpreter_bin - - cwd = os.getcwd() - os.chdir(work_dir) - os.system('make clean') - os.system('make') - prog = resources['prog'] - if 'readable' in os.path.basename(work_dir): - prog = prog + '.ll' - else: - prog = prog + '.bc' - - prog = os.path.join(work_dir, prog) - execlist = [llvm_interpreter_bin, prog] - execlist.extend(resources['inputs'].split(' ')) - print(' '.join(execlist)) - p = subprocess.Popen(execlist) - p.wait() - os.chdir(cwd) - - if p.returncode != 0: - return ("FAIL: ERROR in running on lli: " + llvm_interpreter_bin + ", test: " + work_dir) - else: - return ("PASS") + cwd = os.getcwd() + os.chdir(work_dir) + os.system("make clean") + os.system("make") + prog = resources["prog"] + if "readable" in os.path.basename(work_dir): + prog = prog + ".ll" + else: + prog = prog + ".bc" + + prog = os.path.join(work_dir, prog) + execlist = [llvm_interpreter_bin, prog] + execlist.extend(resources["inputs"].split(" ")) + print(" ".join(execlist)) + p = subprocess.Popen(execlist) + p.wait() + os.chdir(cwd) + + if p.returncode != 0: + return ( + "FAIL: ERROR in running on lli: " + + llvm_interpreter_bin + + ", test: " + + work_dir + ) + else: + return "PASS" def test_generate_makefile(*test_list): - global generate_makefile_script - global llvm_interpreter_bin - - r = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - llfi_tools_dir = os.path.join(script_dir, os.pardir, os.pardir, 'tools') - generate_makefile_script = os.path.join(llfi_tools_dir, 'GenerateMakefile') - sys.path.append(os.path.join(script_dir, os.pardir, os.pardir, 'config')) - import llvm_paths - llvm_interpreter_bin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/lli") - - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 - - work_dict = {} - for test in suite["MakefileGeneration"]: - if len(test_list) == 0 or test in test_list or "all" in test_list: - work_dict["./MakefileGeneration/"+test] = suite["MakefileGeneration"][test] - - result_list = [] - for test_path in work_dict: - print ("MSG: Testing GenerateMakfile on:", test_path) - work_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) - result = callGenerateMakefile(work_dir, work_dict[test_path]) - if result != 'PASS': - r += 1 - else: - result = callLLVMInterpreter(work_dir, work_dict[test_path]) - if result != 'PASS': - r += 1 - record = {"name": test_path, "result": result} - result_list.append(record) - - return r, result_list + global generate_makefile_script + global llvm_interpreter_bin + + r = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + llfi_tools_dir = os.path.join(script_dir, os.pardir, os.pardir, "tools") + generate_makefile_script = os.path.join(llfi_tools_dir, "GenerateMakefile") + sys.path.append(os.path.join(script_dir, os.pardir, os.pardir, "config")) + import llvm_paths + + llvm_interpreter_bin = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/lli") + + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + work_dict = {} + for test in suite["MakefileGeneration"]: + if len(test_list) == 0 or test in test_list or "all" in test_list: + work_dict["./MakefileGeneration/" + test] = suite["MakefileGeneration"][ + test + ] + + result_list = [] + for test_path in work_dict: + print("MSG: Testing GenerateMakfile on:", test_path) + work_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) + result = callGenerateMakefile(work_dir, work_dict[test_path]) + if result != "PASS": + r += 1 + else: + result = callLLVMInterpreter(work_dir, work_dict[test_path]) + if result != "PASS": + r += 1 + record = {"name": test_path, "result": result} + result_list.append(record) + + return r, result_list + if __name__ == "__main__": - r, result_list = test_generate_makefile(*sys.argv[1:]) - print ("=============== Result ===============") - for record in result_list: - print(record["name"], "\t\t", record["result"]) + r, result_list = test_generate_makefile(*sys.argv[1:]) + print("=============== Result ===============") + for record in result_list: + print(record["name"], "\t\t", record["result"]) - sys.exit(r) \ No newline at end of file + sys.exit(r) diff --git a/test_suite/SCRIPTS/test_instruction_duplication.py b/test_suite/SCRIPTS/test_instruction_duplication.py new file mode 100644 index 00000000..0b618cb1 --- /dev/null +++ b/test_suite/SCRIPTS/test_instruction_duplication.py @@ -0,0 +1,485 @@ +#!/usr/bin/env python3 + +""" +Tests for the Selective Instruction Duplication (SID) pass +(llvm_passes/instruction_duplication/InstructionDuplication.cpp). + +The pass is compiled to SEDPasses.so and invoked via opt with the new pass +manager (--passes=InstructionDuplicationPass). All tests write temporary LLVM +IR, run opt, and inspect the output. + +Tests are reported as SKIP when SEDPasses.so has not been built yet. + +Run standalone: python3 test_instruction_duplication.py +Via test driver: llfi_test --all_ml +""" + +import os +import re +import subprocess +import sys +import tempfile + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _find_build_dir(): + """Return the CMake build root (two levels above this script's directory).""" + return os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + + +def _find_source_root(): + cache_file = os.path.join(_find_build_dir(), "CMakeCache.txt") + if not os.path.isfile(cache_file): + return None + with open(cache_file) as f: + for line in f: + if line.startswith("Project_SOURCE_DIR:STATIC="): + return line.split("=", 1)[1].strip() + return None + + +def _find_opt(): + """Find the opt binary: prefer the LLVM install used by the build.""" + build_dir = _find_build_dir() + config_dir = os.path.join(build_dir, "config") + if os.path.isdir(config_dir) and config_dir not in sys.path: + sys.path.insert(0, config_dir) + try: + import llvm_paths + + candidate = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin", "opt") + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + except (ImportError, AttributeError): + pass + import shutil + + return shutil.which("opt") + + +def _find_sed_so(): + """Return path to SEDPasses.so in the build tree, or None.""" + build_dir = _find_build_dir() + candidate = os.path.join( + build_dir, "llvm_passes", "instruction_duplication", "SEDPasses.so" + ) + return candidate if os.path.isfile(candidate) else None + + +def _run_pass(opt, sed_so, ir_text, extra_flags=None, tmpdir=None): + """ + Write ir_text to a temp file, run InstructionDuplicationPass via opt, + and return (returncode, stdout_text). + + extra_flags is a list of additional cl::opt flags + (e.g. ['--enableChainDuplication']). + """ + ir_path = os.path.join(tmpdir, "input.ll") + out_path = os.path.join(tmpdir, "output.ll") + with open(ir_path, "w") as f: + f.write(ir_text) + + cmd = [ + opt, + "-load-pass-plugin", + sed_so, + "--passes=InstructionDuplicationPass", + "-S", + ir_path, + "-o", + out_path, + ] + if extra_flags: + cmd.extend(extra_flags) + + p = subprocess.run(cmd, capture_output=True, text=True) + output = "" + if os.path.isfile(out_path): + with open(out_path) as f: + output = f.read() + return p.returncode, output + + +# --------------------------------------------------------------------------- +# Minimal LLVM IR fixtures +# Note: 1986948931 == "conv" operator ID; 119251066446157 == "matmul". +# OMInstrumentPoint(id, 2) marks the start of an operator region; +# OMInstrumentPoint(id, 1) marks the end. +# --------------------------------------------------------------------------- + +_IR_INSIDE_BOUNDARY = """\ +; ModuleID = 'test_inside' +target triple = "x86_64-unknown-linux-gnu" + +declare void @OMInstrumentPoint(i64, i64) + +define float @main_graph(float %a, float %b) { +entry: + call void @OMInstrumentPoint(i64 1986948931, i64 2) + %r = fadd float %a, %b + call void @OMInstrumentPoint(i64 1986948931, i64 1) + ret float %r +} +""" + +_IR_OUTSIDE_BOUNDARY = """\ +; ModuleID = 'test_outside' +target triple = "x86_64-unknown-linux-gnu" + +define float @main_graph(float %a, float %b) { +entry: + %r = fadd float %a, %b + ret float %r +} +""" + +_IR_CHAIN = """\ +; ModuleID = 'test_chain' +target triple = "x86_64-unknown-linux-gnu" + +declare void @OMInstrumentPoint(i64, i64) + +define float @main_graph(float %a, float %b, float %c) { +entry: + call void @OMInstrumentPoint(i64 1986948931, i64 2) + %r1 = fadd float %a, %b + %r2 = fmul float %r1, %c + call void @OMInstrumentPoint(i64 1986948931, i64 1) + ret float %r2 +} +""" + +_IR_WRONG_OPERATOR = """\ +; ModuleID = 'test_filter' +target triple = "x86_64-unknown-linux-gnu" + +declare void @OMInstrumentPoint(i64, i64) + +define float @main_graph(float %a, float %b) { +entry: + call void @OMInstrumentPoint(i64 119251066446157, i64 2) + %r = fadd float %a, %b + call void @OMInstrumentPoint(i64 119251066446157, i64 1) + ret float %r +} +""" + +_IR_MULTIPLE_OPS = """\ +; ModuleID = 'test_multi' +target triple = "x86_64-unknown-linux-gnu" + +declare void @OMInstrumentPoint(i64, i64) + +define float @main_graph(float %a, float %b, float %c) { +entry: + call void @OMInstrumentPoint(i64 1986948931, i64 2) + %r1 = fadd float %a, %b + call void @OMInstrumentPoint(i64 1986948931, i64 1) + call void @OMInstrumentPoint(i64 119251066446157, i64 2) + %r2 = fmul float %r1, %c + call void @OMInstrumentPoint(i64 119251066446157, i64 1) + ret float %r2 +} +""" + + +# --------------------------------------------------------------------------- +# Individual test functions +# --------------------------------------------------------------------------- + + +def _test_smoke(opt, sed_so): + """Pass runs on valid IR without error.""" + prefix = "./instruction_duplication/smoke" + with tempfile.TemporaryDirectory() as tmpdir: + rc, _ = _run_pass(opt, sed_so, _IR_INSIDE_BOUNDARY, tmpdir=tmpdir) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + return [{"name": prefix, "result": "PASS"}] + + +def _test_instrumentation_inserted(opt, sed_so): + """compareFloatValues call is injected for an fadd inside an operator boundary.""" + prefix = "./instruction_duplication/instrumentation_inserted" + with tempfile.TemporaryDirectory() as tmpdir: + rc, output = _run_pass(opt, sed_so, _IR_INSIDE_BOUNDARY, tmpdir=tmpdir) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + if not re.search(r"call float @compareFloatValues", output): + return [ + { + "name": prefix, + "result": "FAIL: compareFloatValues call not found in output IR", + } + ] + return [{"name": prefix, "result": "PASS"}] + + +def _test_instruction_duplicated(opt, sed_so): + """The original arithmetic instruction is duplicated in the output IR.""" + prefix = "./instruction_duplication/instruction_duplicated" + with tempfile.TemporaryDirectory() as tmpdir: + rc, output = _run_pass(opt, sed_so, _IR_INSIDE_BOUNDARY, tmpdir=tmpdir) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + # Two fadd instructions should appear (original + duplicate) + fadds = re.findall(r"fadd float", output) + if len(fadds) < 2: + return [ + { + "name": prefix, + "result": f"FAIL: expected >= 2 fadd instructions, found {len(fadds)}", + } + ] + return [{"name": prefix, "result": "PASS"}] + + +def _test_no_duplication_outside_boundary(opt, sed_so): + """Arithmetic outside any OMInstrumentPoint boundary is not duplicated.""" + prefix = "./instruction_duplication/no_duplication_outside_boundary" + with tempfile.TemporaryDirectory() as tmpdir: + rc, output = _run_pass(opt, sed_so, _IR_OUTSIDE_BOUNDARY, tmpdir=tmpdir) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + if re.search(r"call float @compareFloatValues", output): + return [ + { + "name": prefix, + "result": "FAIL: compareFloatValues unexpectedly found in output IR", + } + ] + fadds = re.findall(r"fadd float", output) + if len(fadds) != 1: + return [ + { + "name": prefix, + "result": f"FAIL: expected exactly 1 fadd, found {len(fadds)}", + } + ] + return [{"name": prefix, "result": "PASS"}] + + +def _test_chain_duplication(opt, sed_so): + """With --enableChainDuplication, a consecutive arithmetic chain is duplicated.""" + prefix = "./instruction_duplication/chain_duplication" + with tempfile.TemporaryDirectory() as tmpdir: + rc, output = _run_pass( + opt, + sed_so, + _IR_CHAIN, + extra_flags=["--enableChainDuplication"], + tmpdir=tmpdir, + ) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + if not re.search(r"call float @compareFloatValues", output): + return [ + { + "name": prefix, + "result": "FAIL: compareFloatValues call not found in chain output IR", + } + ] + # Both fadd and fmul should be duplicated + fadds = re.findall(r"fadd float", output) + fmuls = re.findall(r"fmul float", output) + if len(fadds) < 2: + return [ + { + "name": prefix, + "result": f"FAIL: expected >= 2 fadd in chain, found {len(fadds)}", + } + ] + if len(fmuls) < 2: + return [ + { + "name": prefix, + "result": f"FAIL: expected >= 2 fmul in chain, found {len(fmuls)}", + } + ] + return [{"name": prefix, "result": "PASS"}] + + +def _test_operator_filtering(opt, sed_so): + """With --operatorName=conv, a matmul operator region is not instrumented.""" + prefix = "./instruction_duplication/operator_filtering" + with tempfile.TemporaryDirectory() as tmpdir: + rc, output = _run_pass( + opt, + sed_so, + _IR_WRONG_OPERATOR, + extra_flags=["--operatorName=conv"], + tmpdir=tmpdir, + ) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + if re.search(r"call float @compareFloatValues", output): + return [ + { + "name": prefix, + "result": "FAIL: compareFloatValues found despite operator mismatch", + } + ] + return [{"name": prefix, "result": "PASS"}] + + +def _test_multiple_operator_regions(opt, sed_so): + """With --operatorName=conv, only the conv region is instrumented, not matmul.""" + prefix = "./instruction_duplication/multiple_operator_regions" + with tempfile.TemporaryDirectory() as tmpdir: + rc, output = _run_pass( + opt, + sed_so, + _IR_MULTIPLE_OPS, + extra_flags=["--operatorName=conv"], + tmpdir=tmpdir, + ) + if rc != 0: + return [{"name": prefix, "result": f"FAIL: opt exited {rc}"}] + if not re.search(r"call float @compareFloatValues", output): + return [ + { + "name": prefix, + "result": "FAIL: compareFloatValues not found for conv region", + } + ] + # Only the fadd (conv region) should be duplicated, not the fmul (matmul region) + fmuls = re.findall(r"fmul float", output) + if len(fmuls) != 1: + return [ + { + "name": prefix, + "result": f"FAIL: fmul in matmul region was duplicated (found {len(fmuls)})", + } + ] + return [{"name": prefix, "result": "PASS"}] + + +# --------------------------------------------------------------------------- +# Top-level entry point +# --------------------------------------------------------------------------- + + +def _find_real_model_ll(): + """Return path to the pre-compiled mnist model.ll, or None if absent. + + model.ll is produced by running compile.sh in + sample_programs/ml_sample_programs/vision_models/mnist/ which requires + onnx-mlir. When absent the real-model tests are SKIPped. + """ + src = _find_source_root() + if src is None: + return None + candidate = os.path.join( + src, + "sample_programs", + "ml_sample_programs", + "vision_models", + "mnist", + "model.ll", + ) + return candidate if os.path.isfile(candidate) else None + + +def _find_sid_helper_ll(): + """Return path to SIDHelperFunctions.ll (produced by compile_shrd_lib.sh), + or None if it hasn't been built yet.""" + src = _find_source_root() + if src is None: + return None + candidate = os.path.join( + src, + "llvm_passes", + "instruction_duplication", + "shared_lib", + "SIDHelperFunctions.ll", + ) + return candidate if os.path.isfile(candidate) else None + + +def _find_llvm_tool(name): + """Locate an LLVM binary (e.g. llvm-link, lli) beside the opt binary.""" + opt = _find_opt() + if opt: + candidate = os.path.join(os.path.dirname(opt), name) + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + import shutil + + return shutil.which(name) + + +def test_instruction_duplication(): + """ + Run all InstructionDuplication pass tests. + + Returns (returncode, result_list) where returncode is 0 on success. + """ + result_list = [] + skip_msg = "SKIP: SEDPasses.so not found — build LLTFI first" + + sed_so = _find_sed_so() + if not sed_so: + for name in ( + "smoke", + "instrumentation_inserted", + "instruction_duplicated", + "no_duplication_outside_boundary", + "chain_duplication", + "operator_filtering", + "multiple_operator_regions", + ): + result_list.append( + { + "name": f"./instruction_duplication/{name}", + "result": skip_msg, + } + ) + return 0, result_list + + opt = _find_opt() + if not opt: + for name in ( + "smoke", + "instrumentation_inserted", + "instruction_duplicated", + "no_duplication_outside_boundary", + "chain_duplication", + "operator_filtering", + "multiple_operator_regions", + ): + result_list.append( + { + "name": f"./instruction_duplication/{name}", + "result": "SKIP: opt binary not found", + } + ) + return 0, result_list + + result_list.extend(_test_smoke(opt, sed_so)) + result_list.extend(_test_instrumentation_inserted(opt, sed_so)) + result_list.extend(_test_instruction_duplicated(opt, sed_so)) + result_list.extend(_test_no_duplication_outside_boundary(opt, sed_so)) + result_list.extend(_test_chain_duplication(opt, sed_so)) + result_list.extend(_test_operator_filtering(opt, sed_so)) + result_list.extend(_test_multiple_operator_regions(opt, sed_so)) + + # Real onnx-mlir IR tests (SKIP when model.ll is absent — requires onnx-mlir + # to have been run via compile.sh in sample_programs/.../mnist/). + model_ll = _find_real_model_ll() + if model_ll is None: + skip = ( + "SKIP: model.ll not found — run compile.sh in " + "sample_programs/ml_sample_programs/vision_models/mnist/" + ) + has_fail = any(r["result"].startswith("FAIL") for r in result_list) + return (1 if has_fail else 0), result_list + + +if __name__ == "__main__": + rc, results = test_instruction_duplication() + for r in results: + print(r["name"], "\t\t", r["result"]) + sys.exit(rc) diff --git a/test_suite/SCRIPTS/test_ml_models.py b/test_suite/SCRIPTS/test_ml_models.py new file mode 100644 index 00000000..7006fa50 --- /dev/null +++ b/test_suite/SCRIPTS/test_ml_models.py @@ -0,0 +1,549 @@ +#! /usr/bin/env python3 + +""" +Tests for the TensorFlow and PyTorch model compilation and fault injection +pipelines in LLTFI. + +Tests are organised in four tiers; each tier skips when its dependencies are +absent so the suite always finishes cleanly. + + Tier 1 — TensorFlow → ONNX + requires: tensorflow, tf2onnx, onnx + Tier 2 — PyTorch → ONNX + requires: torch, onnx + Tier 3 — ONNX → LLVM IR (onnx-mlir / mlir-translate) + requires: onnx-mlir binary, mlir-translate binary + uses: pre-built model.onnx from the mnist sample dir (if present) + OR a model produced by tier 1/2 + Tier 4 — Fault injection on compiled ML model + requires: LLTFI build, tier-3 output (model.ll + image input) + +Run with: python3 test_ml_models.py +Or via the test driver: llfi_test --all_ml +""" + +import os +import sys +import shutil +import subprocess +import tempfile + +# --------------------------------------------------------------------------- +# Helpers shared by all tiers +# --------------------------------------------------------------------------- + + +def _find_build_dir(): + return os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + + +def _find_source_root(): + cache_file = os.path.join(_find_build_dir(), "CMakeCache.txt") + if not os.path.isfile(cache_file): + return None + with open(cache_file) as f: + for line in f: + if line.startswith("Project_SOURCE_DIR:STATIC="): + return line.split("=", 1)[1].strip() + return None + + +def _has_dep(*modules): + for m in modules: + try: + __import__(m) + except ImportError: + return False + return True + + +def _find_binary(name): + """Search ONNX_MLIR_BUILD/bin, LLVM_DST_ROOT/bin, then PATH.""" + for env_var in ("ONNX_MLIR_BUILD", "LLVM_DST_ROOT"): + base = os.environ.get(env_var) + if base: + p = os.path.join(base, "bin", name) + if os.path.isfile(p) and os.access(p, os.X_OK): + return p + return shutil.which(name) + + +def _import_llvm_paths(): + config_dir = os.path.join(_find_build_dir(), "config") + if config_dir not in sys.path: + sys.path.insert(0, config_dir) + import llvm_paths + + return llvm_paths + + +# --------------------------------------------------------------------------- +# Tier 1: TensorFlow → ONNX +# --------------------------------------------------------------------------- + + +def _train_tf_model(saved_model_dir): + """Train a minimal single Dense-layer TF model on random data (1 epoch).""" + import numpy as np + import tensorflow as tf + + x = np.random.rand(32, 28, 28).astype("float32") + y = np.random.randint(0, 10, 32) + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(10, activation="softmax"), + ] + ) + model.compile( + optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] + ) + model.fit(x, y, epochs=1, verbose=0) + model.save(saved_model_dir) + + +def test_tensorflow_pipeline(): + result_list = [] + prefix = "./ml/tensorflow" + + if not _has_dep("tensorflow", "tf2onnx", "onnx"): + for name in ("train", "tf_to_onnx", "onnx_valid"): + result_list.append( + { + "name": f"{prefix}/{name}", + "result": "SKIP: tensorflow, tf2onnx, or onnx not installed", + } + ) + return result_list + + with tempfile.TemporaryDirectory() as tmpdir: + saved_model_dir = os.path.join(tmpdir, "model.tf") + onnx_path = os.path.join(tmpdir, "model.onnx") + + # --- Test 1: model trains without error --- + try: + _train_tf_model(saved_model_dir) + if not os.path.isdir(saved_model_dir): + raise RuntimeError("SavedModel directory not created") + result_list.append({"name": f"{prefix}/train", "result": "PASS"}) + except Exception as e: + result_list.append({"name": f"{prefix}/train", "result": f"FAIL: {e}"}) + return result_list + + # --- Test 2: convert SavedModel → ONNX --- + p = subprocess.run( + [ + sys.executable, + "-m", + "tf2onnx.convert", + "--saved-model", + saved_model_dir, + "--output", + onnx_path, + ], + capture_output=True, + text=True, + ) + if p.returncode != 0 or not os.path.isfile(onnx_path): + result_list.append( + { + "name": f"{prefix}/tf_to_onnx", + "result": f"FAIL: tf2onnx exited {p.returncode}: {p.stderr.strip()[:200]}", + } + ) + return result_list + result_list.append({"name": f"{prefix}/tf_to_onnx", "result": "PASS"}) + + # --- Test 3: ONNX model is structurally valid --- + try: + import onnx + + model = onnx.load(onnx_path) + onnx.checker.check_model(model) + result_list.append({"name": f"{prefix}/onnx_valid", "result": "PASS"}) + except Exception as e: + result_list.append({"name": f"{prefix}/onnx_valid", "result": f"FAIL: {e}"}) + + return result_list + + +# --------------------------------------------------------------------------- +# Tier 2: PyTorch → ONNX +# --------------------------------------------------------------------------- + + +def _export_pytorch_model(onnx_path): + """Export a minimal untrained PyTorch model to ONNX.""" + import torch + import torch.nn as nn + + class TinyNet(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(784, 10) + + def forward(self, x): + return torch.softmax(self.fc(x), dim=1) + + model = TinyNet() + dummy = torch.randn(1, 784) + torch.onnx.export( + model, + dummy, + onnx_path, + input_names=["input"], + output_names=["output"], + opset_version=11, + verbose=False, + ) + + +def test_pytorch_pipeline(): + result_list = [] + prefix = "./ml/pytorch" + + if not _has_dep("torch", "onnx"): + for name in ("export_onnx", "onnx_valid"): + result_list.append( + { + "name": f"{prefix}/{name}", + "result": "SKIP: torch or onnx not installed", + } + ) + return result_list + + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = os.path.join(tmpdir, "model.onnx") + + # --- Test 1: export to ONNX --- + try: + _export_pytorch_model(onnx_path) + if not os.path.isfile(onnx_path): + raise RuntimeError("ONNX file not created") + result_list.append({"name": f"{prefix}/export_onnx", "result": "PASS"}) + except Exception as e: + result_list.append( + {"name": f"{prefix}/export_onnx", "result": f"FAIL: {e}"} + ) + return result_list + + # --- Test 2: ONNX model is structurally valid --- + try: + import onnx + + model = onnx.load(onnx_path) + onnx.checker.check_model(model) + result_list.append({"name": f"{prefix}/onnx_valid", "result": "PASS"}) + except Exception as e: + result_list.append({"name": f"{prefix}/onnx_valid", "result": f"FAIL: {e}"}) + + return result_list + + +# --------------------------------------------------------------------------- +# Tier 3: ONNX → LLVM IR via onnx-mlir + mlir-translate +# --------------------------------------------------------------------------- + + +def _find_prebuilt_onnx(source_root): + """Return path to pre-built model.onnx in the mnist sample dir, or None.""" + candidate = os.path.join( + source_root, + "sample_programs", + "ml_sample_programs", + "vision_models", + "mnist", + "model.onnx", + ) + return candidate if os.path.isfile(candidate) else None + + +def test_onnx_to_ir(onnx_path=None): + """ + Compile an ONNX model to LLVM IR using onnx-mlir and mlir-translate. + + If onnx_path is None, tries to find the pre-built model.onnx from the + mnist sample directory. Skips if onnx-mlir or mlir-translate are not on + PATH or in ONNX_MLIR_BUILD/bin. + """ + result_list = [] + prefix = "./ml/onnx_to_ir" + + onnx_mlir = _find_binary("onnx-mlir") + mlir_translate = _find_binary("mlir-translate") + + if not onnx_mlir or not mlir_translate: + missing = [] + if not onnx_mlir: + missing.append("onnx-mlir") + if not mlir_translate: + missing.append("mlir-translate") + for name in ("compile_mlir", "translate_to_ll"): + result_list.append( + { + "name": f"{prefix}/{name}", + "result": f'SKIP: {", ".join(missing)} not found ' + f"(set ONNX_MLIR_BUILD or add to PATH)", + } + ) + return result_list + + source_root = _find_source_root() + if onnx_path is None and source_root: + onnx_path = _find_prebuilt_onnx(source_root) + + if onnx_path is None or not os.path.isfile(onnx_path): + for name in ("compile_mlir", "translate_to_ll"): + result_list.append( + { + "name": f"{prefix}/{name}", + "result": "SKIP: no ONNX model available " + "(run compile.sh in sample_programs/ml_sample_programs/" + "vision_models/mnist/ first, or install tensorflow/torch)", + } + ) + return result_list + + with tempfile.TemporaryDirectory() as tmpdir: + # Copy ONNX to tmpdir so onnx-mlir outputs land there cleanly + local_onnx = os.path.join(tmpdir, "model.onnx") + shutil.copy2(onnx_path, local_onnx) + mlir_out = os.path.join(tmpdir, "model.onnx.mlir") + ll_out = os.path.join(tmpdir, "model.mlir.ll") + + # --- Test 1: onnx-mlir compiles ONNX → MLIR --- + p = subprocess.run( + [onnx_mlir, "--EmitLLVMIR", local_onnx], + capture_output=True, + text=True, + cwd=tmpdir, + ) + if p.returncode != 0 or not os.path.isfile(mlir_out): + result_list.append( + { + "name": f"{prefix}/compile_mlir", + "result": f"FAIL: onnx-mlir exited {p.returncode}: {p.stderr.strip()[:200]}", + } + ) + return result_list + result_list.append({"name": f"{prefix}/compile_mlir", "result": "PASS"}) + + # --- Test 2: mlir-translate produces LLVM IR --- + with open(ll_out, "w") as ll_file: + p = subprocess.run( + [mlir_translate, "-mlir-to-llvmir", mlir_out], + stdout=ll_file, + stderr=subprocess.PIPE, + text=True, + cwd=tmpdir, + ) + if ( + p.returncode != 0 + or not os.path.isfile(ll_out) + or os.path.getsize(ll_out) == 0 + ): + result_list.append( + { + "name": f"{prefix}/translate_to_ll", + "result": f"FAIL: mlir-translate exited {p.returncode}: {p.stderr.strip()[:200]}", + } + ) + else: + result_list.append({"name": f"{prefix}/translate_to_ll", "result": "PASS"}) + + return result_list + + +# --------------------------------------------------------------------------- +# Tier 4: Fault injection on a compiled ML model +# --------------------------------------------------------------------------- + + +def test_fault_injection(): + """ + Smoke-test the LLTFI fault injection pipeline on a compiled ML model. + + Requires: + - model.ll and eight.png in sample_programs/.../mnist/ + (produced by compile.sh — run that first if this test SKIPs) + - LLTFI build with instrument, profile, injectfault scripts + """ + result_list = [] + prefix = "./ml/fault_injection" + + source_root = _find_source_root() + if source_root is None: + result_list.append( + { + "name": f"{prefix}/instrument", + "result": "FAIL: cannot determine source root from CMakeCache.txt", + } + ) + return result_list + + mnist_dir = os.path.join( + source_root, "sample_programs", "ml_sample_programs", "vision_models", "mnist" + ) + model_ll = os.path.join(mnist_dir, "model.ll") + image_png = os.path.join(mnist_dir, "eight.png") + op_seq = os.path.join(mnist_dir, "expected_op_seq.txt") + + if not os.path.isfile(model_ll): + for name in ("instrument", "profile", "inject"): + result_list.append( + { + "name": f"{prefix}/{name}", + "result": "SKIP: model.ll not found — run compile.sh in " + "sample_programs/ml_sample_programs/vision_models/mnist/ first", + } + ) + return result_list + + try: + _import_llvm_paths() + except ImportError as e: + result_list.append( + { + "name": f"{prefix}/instrument", + "result": f"FAIL: cannot import llvm_paths — {e}", + } + ) + return result_list + + build_dir = _find_build_dir() + instrument_bin = os.path.join(build_dir, "bin", "instrument") + profile_bin = os.path.join(build_dir, "bin", "profile") + inject_bin = os.path.join(build_dir, "bin", "injectfault") + + for label, path in [ + ("instrument", instrument_bin), + ("profile", profile_bin), + ("injectfault", inject_bin), + ]: + if not os.path.isfile(path): + result_list.append( + { + "name": f"{prefix}/instrument", + "result": f"FAIL: {label} not found at {path}", + } + ) + return result_list + + with tempfile.TemporaryDirectory() as tmpdir: + # Copy model and support files into a clean workdir so llfi/ output + # lands in tmpdir and does not pollute the source tree. + local_ll = os.path.join(tmpdir, "model.ll") + local_img = os.path.join(tmpdir, "eight.png") + shutil.copy2(model_ll, local_ll) + if os.path.isfile(image_png): + shutil.copy2(image_png, local_img) + # Copy input.yaml so instrument knows FI config + local_yaml = os.path.join(tmpdir, "input.yaml") + shutil.copy2(os.path.join(mnist_dir, "input.yaml"), local_yaml) + + # Read expected_op_seq if available (second arg to profile/inject) + extra_arg = "" + if os.path.isfile(op_seq): + with open(op_seq) as f: + extra_arg = f.read().strip() + + # --- Test 1: instrument --- + p = subprocess.run( + [instrument_bin, "--readable", "-L", "$ONNX_MLIR_BUILD/Debug/lib", "-lcruntime", "-ljson-c", "-lprotobuf", local_ll], + capture_output=True, + text=True, + cwd=tmpdir, + ) + llfi_dir = os.path.join(tmpdir, "llfi") + if p.returncode != 0 or not os.path.isdir(llfi_dir): + result_list.append( + { + "name": f"{prefix}/instrument", + "result": f"FAIL: instrument exited {p.returncode}: {p.stderr.strip()[:200]}", + } + ) + return result_list + result_list.append({"name": f"{prefix}/instrument", "result": "PASS"}) + + # --- Test 2: profile --- + profiling_exe = os.path.join(llfi_dir, "model-profiling.exe") + if not os.path.isfile(profiling_exe): + result_list.append( + { + "name": f"{prefix}/profile", + "result": f"FAIL: profiling exe not found at {profiling_exe}", + } + ) + return result_list + + profile_cmd = [profile_bin, profiling_exe] + if os.path.isfile(local_img): + profile_cmd.append(local_img) + if extra_arg: + profile_cmd.append(extra_arg) + p = subprocess.run(profile_cmd, capture_output=True, text=True, cwd=tmpdir) + if p.returncode != 0: + result_list.append( + { + "name": f"{prefix}/profile", + "result": f"FAIL: profile exited {p.returncode}: {p.stderr.strip()[:200]}", + } + ) + return result_list + result_list.append({"name": f"{prefix}/profile", "result": "PASS"}) + + # --- Test 3: inject (1 run to verify the pipeline works) --- + inject_exe = os.path.join(llfi_dir, "model-faultinjection.exe") + if not os.path.isfile(inject_exe): + result_list.append( + { + "name": f"{prefix}/inject", + "result": f"FAIL: fault injection exe not found at {inject_exe}", + } + ) + return result_list + + inject_cmd = [inject_bin, inject_exe] + if os.path.isfile(local_img): + inject_cmd.append(local_img) + if extra_arg: + inject_cmd.append(extra_arg) + p = subprocess.run(inject_cmd, capture_output=True, text=True, cwd=tmpdir) + # A non-zero exit from the injected binary is normal (crash = fault effect) + stat_dir = os.path.join(llfi_dir, "llfi_stat_output") + if not os.path.isdir(stat_dir): + result_list.append( + { + "name": f"{prefix}/inject", + "result": "FAIL: no llfi_stat_output directory produced", + } + ) + else: + result_list.append({"name": f"{prefix}/inject", "result": "PASS"}) + + return result_list + + +# --------------------------------------------------------------------------- +# Top-level entry point +# --------------------------------------------------------------------------- + + +def test_ml_models(): + """Run all four tiers and return a combined result list.""" + all_results = [] + all_results.extend(test_tensorflow_pipeline()) + all_results.extend(test_pytorch_pipeline()) + all_results.extend(test_onnx_to_ir()) + all_results.extend(test_fault_injection()) + + has_fail = any(r["result"].startswith("FAIL") for r in all_results) + return (1 if has_fail else 0), all_results + + +if __name__ == "__main__": + rc, results = test_ml_models() + for r in results: + print(r["name"], "\t\t", r["result"]) + sys.exit(rc) diff --git a/test_suite/SCRIPTS/test_ml_tools.py b/test_suite/SCRIPTS/test_ml_tools.py new file mode 100644 index 00000000..ee75be6b --- /dev/null +++ b/test_suite/SCRIPTS/test_ml_tools.py @@ -0,0 +1,237 @@ +#! /usr/bin/env python3 + +""" +Tests for the ONNX/ML analysis tools in tools/: + - CompareLayerOutputs.py (requires: onnx, pygraphviz) + - ExtendONNXModel.py (requires: onnx) + - outputONNXGraph.py (requires: onnx, pydot) + +Tests that require a missing dependency are reported as SKIP, not FAIL. +""" + +import os +import sys +import json +import subprocess +import tempfile + + +def _find_build_dir(): + return os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + + +def _find_source_root(): + cache_file = os.path.join(_find_build_dir(), "CMakeCache.txt") + if not os.path.isfile(cache_file): + return None + with open(cache_file) as f: + for line in f: + if line.startswith("Project_SOURCE_DIR:STATIC="): + return line.split("=", 1)[1].strip() + return None + + +def _has_dep(*modules): + for m in modules: + try: + __import__(m) + except ImportError: + return False + return True + + +def _make_minimal_onnx(path): + """Create a minimal single-Relu ONNX model and save to path.""" + import onnx + from onnx import helper, TensorProto + + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3]) + relu = helper.make_node("Relu", inputs=["X"], outputs=["Y"]) + graph = helper.make_graph([relu], "test_graph", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + onnx.save(model, path) + + +def _make_layer_json(layers_data): + """Build the JSON format expected by CompareLayerOutputs.py.""" + result = {} + for i, values in enumerate(layers_data): + result[str(i)] = { + "Layer Id": str(i), + "Rank": 1, + "Number of Elements": len(values), + "Shape": [len(values)], + "Data": values, + } + return result + + +def test_ml_tools(): + result_list = [] + + source_root = _find_source_root() + if source_root is None: + result_list.append( + { + "name": "./ml_tools/setup", + "result": "FAIL: cannot determine source root from CMakeCache.txt", + } + ) + return 1, result_list + + compare_script = os.path.join(source_root, "tools", "CompareLayerOutputs.py") + extend_script = os.path.join(source_root, "tools", "ExtendONNXModel.py") + graph_script = os.path.join(source_root, "tools", "outputONNXGraph.py") + + # ----------------------------------------------------------------------- + # CompareLayerOutputs — requires onnx + pygraphviz to import the module + # ----------------------------------------------------------------------- + if not _has_dep("onnx", "pygraphviz"): + for name in ("compare_layer_outputs_match", "compare_layer_outputs_diff"): + result_list.append( + { + "name": f"./ml_tools/{name}", + "result": "SKIP: onnx or pygraphviz not installed", + } + ) + else: + with tempfile.TemporaryDirectory() as tmpdir: + golden_json = os.path.join(tmpdir, "golden.json") + same_json = os.path.join(tmpdir, "faulty_same.json") + diff_json = os.path.join(tmpdir, "faulty_diff.json") + dummy_onnx = os.path.join(tmpdir, "dummy.onnx") + + _make_minimal_onnx(dummy_onnx) + + golden_data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + same_data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + diff_data = [[1.0, 2.0, 3.0], [4.0, 5.0, 99.0]] # last value changed + + for path, data in [ + (golden_json, golden_data), + (same_json, same_data), + (diff_json, diff_data), + ]: + with open(path, "w") as f: + json.dump(_make_layer_json(data), f) + + # Test: identical data → no mismatch + p = subprocess.run( + ["python3", compare_script, golden_json, same_json, dummy_onnx], + capture_output=True, + text=True, + ) + if p.returncode != 0 or "No mismatch found" not in p.stdout: + result_list.append( + { + "name": "./ml_tools/compare_layer_outputs_match", + "result": f'FAIL: expected "No mismatch found", got: {p.stdout.strip()[:120]}', + } + ) + else: + result_list.append( + {"name": "./ml_tools/compare_layer_outputs_match", "result": "PASS"} + ) + + # Test: differing data → mismatch detected + p = subprocess.run( + ["python3", compare_script, golden_json, diff_json, dummy_onnx], + capture_output=True, + text=True, + ) + if p.returncode != 0 or "No mismatch found" in p.stdout: + result_list.append( + { + "name": "./ml_tools/compare_layer_outputs_diff", + "result": f"FAIL: expected mismatch detection, got: {p.stdout.strip()[:120]}", + } + ) + else: + result_list.append( + {"name": "./ml_tools/compare_layer_outputs_diff", "result": "PASS"} + ) + + # ----------------------------------------------------------------------- + # ExtendONNXModel — requires onnx + # ----------------------------------------------------------------------- + if not _has_dep("onnx"): + result_list.append( + { + "name": "./ml_tools/extend_onnx_model", + "result": "SKIP: onnx not installed", + } + ) + else: + with tempfile.TemporaryDirectory() as tmpdir: + model_in = os.path.join(tmpdir, "model.onnx") + model_out = os.path.join(tmpdir, "extended.onnx") + _make_minimal_onnx(model_in) + + p = subprocess.run( + [ + "python3", + extend_script, + "--model_path", + model_in, + "--output_model_path", + model_out, + "--layers", + "all", + ], + capture_output=True, + text=True, + ) + if p.returncode != 0 or not os.path.isfile(model_out): + result_list.append( + { + "name": "./ml_tools/extend_onnx_model", + "result": f"FAIL: exited {p.returncode}: {p.stderr.strip()[:120]}", + } + ) + else: + result_list.append( + {"name": "./ml_tools/extend_onnx_model", "result": "PASS"} + ) + + # ----------------------------------------------------------------------- + # outputONNXGraph — requires onnx + pydot + # ----------------------------------------------------------------------- + if not _has_dep("onnx", "pydot"): + result_list.append( + { + "name": "./ml_tools/output_onnx_graph", + "result": "SKIP: onnx or pydot not installed", + } + ) + else: + with tempfile.TemporaryDirectory() as tmpdir: + model_path = os.path.join(tmpdir, "model.onnx") + dot_path = os.path.join(tmpdir, "model.dot") + _make_minimal_onnx(model_path) + + p = subprocess.run( + ["python3", graph_script, model_path, dot_path], + capture_output=True, + text=True, + ) + if p.returncode != 0 or not os.path.isfile(dot_path): + result_list.append( + { + "name": "./ml_tools/output_onnx_graph", + "result": f"FAIL: exited {p.returncode}: {p.stderr.strip()[:120]}", + } + ) + else: + result_list.append( + {"name": "./ml_tools/output_onnx_graph", "result": "PASS"} + ) + + return 0, result_list + + +if __name__ == "__main__": + rc, results = test_ml_tools() + for r in results: + print(r["name"], "\t\t", r["result"]) + sys.exit(rc) diff --git a/test_suite/SCRIPTS/test_trace_tools.py b/test_suite/SCRIPTS/test_trace_tools.py index 4e966bce..a6193ff7 100644 --- a/test_suite/SCRIPTS/test_trace_tools.py +++ b/test_suite/SCRIPTS/test_trace_tools.py @@ -2,7 +2,6 @@ import os import sys -import shutil import yaml import subprocess @@ -11,135 +10,161 @@ traceontograph_script = "" tracetodot_script = "" -def callTraceTools(work_dir, resources): - global tracediff_script - global traceunion_script - global traceontograph_script - global tracetodot_script - - golden_trace = resources['trace_prof'] - golden_trace_file = os.path.join(work_dir, golden_trace) - if os.path.isfile(golden_trace_file) == False: - return ("FAIL: golden_trace_file not found:", golden_trace) - ## call tracediff to generate the trace report file - reports_list = [] - for faulty_trace in resources["trace_inject"]: - faulty_trace_file = os.path.join(work_dir, faulty_trace) - if os.path.isfile(faulty_trace_file) == False: - print ("WARNING: faulty_trace_file not found:", faulty_trace, "work_dir:", work_dir) - pass - else: - report_name = '.'.join(faulty_trace.split('.')[0:-1])+'.report.'+faulty_trace.split('.')[-1] - report_file = os.path.join(work_dir, report_name) - commands = [tracediff_script, golden_trace_file, faulty_trace_file, '>', report_file] - p = subprocess.Popen(' '.join(commands), shell=True) - p.wait() - if p.returncode != 0: - return ("FAIL: \'tracediff\' quits unnormally!") - if os.path.isfile(report_file) == False: - return ("FAIL: report_file not generated by \'tracediff\':", report_name) - if os.path.getsize(report_file) == 0: - return ("FAIL: report_file generated by \'tracediff\' is empty:", report_name) - reports_list.append(report_file) - - ## call traceunion to generate a union of all reports - united_report_name = 'llfi.united.trace.report.txt' - united_report_file = os.path.join(work_dir, united_report_name) - commands = [traceunion_script] - commands.extend(reports_list) - commands.extend(['>', united_report_file]) - p = subprocess.Popen(' '.join(commands), shell=True) - p.wait() - if p.returncode != 0: - return ("FAIL: \'traceunion\' quits unnormally!") - if os.path.isfile(report_file) == False: - return ("FAIL: united report_file not generated by \'traceunion\':", united_report_name) - if os.path.getsize(report_file) == 0: - return ("FAIL: united report_file generated by \'traceunion\' is empty:", united_report_name) - - ## call traceontograph to generate the dot file - cdfg_prof_file = os.path.join(work_dir, resources['cdfg_prof']) - if os.path.isfile(cdfg_prof_file) == False: - return ("FAIL: cdfg_prof_file not found:", resources['cdfg_prof']) - commands = [traceontograph_script, united_report_file, cdfg_prof_file, '>'] - cdfg_faulty_name = 'llfi.faulty.graph.dot' - cdfg_faulty_file = os.path.join(work_dir, cdfg_faulty_name) - commands.append(cdfg_faulty_file) - p = subprocess.Popen(' '.join(commands), shell=True) - p.wait() - if p.returncode != 0: - return ("FAIL: \'traceontograph\' quits unnormally!") - if os.path.isfile(report_file) == False: - return ("FAIL: cdfg_faulty_file not generated by \'traceontograph\':", cdfg_faulty_name) - if os.path.getsize(report_file) == 0: - return ("FAIL: cdfg_faulty_file generated by \'traceontograph\' is empty:", cdfg_faulty_name) - - ## call tracetodot to generate all dot and report files - current_dir = os.path.abspath(os.path.curdir) - llfi_stat_dir = os.path.join(work_dir, 'llfi', 'llfi_stat_output') - os.chdir(llfi_stat_dir) - #print (os.getcwd()) - p = subprocess.Popen(tracetodot_script, shell=True) - p.wait() - os.chdir(current_dir) - if p.returncode != 0: - return ("FAIL: \'tracetodot\' quits unnormally!") - trace_dir = os.path.join(work_dir, 'llfi', 'trace_report_output') - #print(trace_dir) - if os.path.isdir(trace_dir) == False: - return ("FAIL: trace_report_output/ not generated by\'tracetodot\'!", work_dir) - t = [f for f in os.listdir(trace_dir)] - if len(t) < 2: - return ("FAIL: dot/report files generated by\'tracetodot\' not complete!", work_dir) - - return "PASS" +def callTraceTools(work_dir, resources): + golden_trace = resources["trace_prof"] + golden_trace_file = os.path.join(work_dir, golden_trace) + if not os.path.isfile(golden_trace_file): + return ("FAIL: golden_trace_file not found:", golden_trace) + ## call tracediff to generate the trace report file + reports_list = [] + for faulty_trace in resources["trace_inject"]: + faulty_trace_file = os.path.join(work_dir, faulty_trace) + if not os.path.isfile(faulty_trace_file): + print( + "WARNING: faulty_trace_file not found:", + faulty_trace, + "work_dir:", + work_dir, + ) + pass + else: + report_name = ( + ".".join(faulty_trace.split(".")[0:-1]) + + ".report." + + faulty_trace.split(".")[-1] + ) + report_file = os.path.join(work_dir, report_name) + with open(report_file, "w") as report_out: + p = subprocess.Popen( + [tracediff_script, golden_trace_file, faulty_trace_file], + stdout=report_out, + ) + p.wait() + if p.returncode != 0: + return "FAIL: 'tracediff' quits unnormally!" + if not os.path.isfile(report_file): + return ("FAIL: report_file not generated by 'tracediff':", report_name) + if os.path.getsize(report_file) == 0: + return ( + "FAIL: report_file generated by 'tracediff' is empty:", + report_name, + ) + reports_list.append(report_file) + + ## call traceunion to generate a union of all reports + united_report_name = "llfi.united.trace.report.txt" + united_report_file = os.path.join(work_dir, united_report_name) + commands = [traceunion_script] + reports_list + with open(united_report_file, "w") as united_out: + p = subprocess.Popen(commands, stdout=united_out) + p.wait() + if p.returncode != 0: + return "FAIL: 'traceunion' quits unnormally!" + if not os.path.isfile(report_file): + return ( + "FAIL: united report_file not generated by 'traceunion':", + united_report_name, + ) + if os.path.getsize(report_file) == 0: + return ( + "FAIL: united report_file generated by 'traceunion' is empty:", + united_report_name, + ) + + ## call traceontograph to generate the dot file + cdfg_prof_file = os.path.join(work_dir, resources["cdfg_prof"]) + if not os.path.isfile(cdfg_prof_file): + return ("FAIL: cdfg_prof_file not found:", resources["cdfg_prof"]) + cdfg_faulty_name = "llfi.faulty.graph.dot" + cdfg_faulty_file = os.path.join(work_dir, cdfg_faulty_name) + with open(cdfg_faulty_file, "w") as cdfg_out: + p = subprocess.Popen( + [traceontograph_script, united_report_file, cdfg_prof_file], stdout=cdfg_out + ) + p.wait() + if p.returncode != 0: + return "FAIL: 'traceontograph' quits unnormally!" + if not os.path.isfile(report_file): + return ( + "FAIL: cdfg_faulty_file not generated by 'traceontograph':", + cdfg_faulty_name, + ) + if os.path.getsize(report_file) == 0: + return ( + "FAIL: cdfg_faulty_file generated by 'traceontograph' is empty:", + cdfg_faulty_name, + ) + + ## call tracetodot to generate all dot and report files + current_dir = os.path.abspath(os.path.curdir) + llfi_stat_dir = os.path.join(work_dir, "llfi", "llfi_stat_output") + os.chdir(llfi_stat_dir) + # print (os.getcwd()) + p = subprocess.Popen([tracetodot_script]) + p.wait() + os.chdir(current_dir) + if p.returncode != 0: + return "FAIL: 'tracetodot' quits unnormally!" + trace_dir = os.path.join(work_dir, "llfi", "trace_report_output") + # print(trace_dir) + if not os.path.isdir(trace_dir): + return ("FAIL: trace_report_output/ not generated by'tracetodot'!", work_dir) + t = [f for f in os.listdir(trace_dir)] + if len(t) < 2: + return ( + "FAIL: dot/report files generated by'tracetodot' not complete!", + work_dir, + ) + + return "PASS" def test_trace_tools(*test_list): - global tracediff_script - global traceunion_script - global traceontograph_script - global tracetodot_script - - r = 0 - suite = {} - script_dir = os.path.dirname(os.path.realpath(__file__)) - llfi_tools_dir = os.path.join(script_dir, '../../tools') - tracediff_script = os.path.join(llfi_tools_dir, "tracediff") - traceunion_script = os.path.join(llfi_tools_dir, "traceunion") - traceontograph_script = os.path.join(llfi_tools_dir, "traceontograph") - tracetodot_script = os.path.join(llfi_tools_dir, "tracetodot") - - testsuite_dir = os.path.join(script_dir, os.pardir) - with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: - try: - suite = yaml.safe_load(f) - except: - print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) - return -1 - - work_dict = {} - for test in suite["Traces"]: - if len(test_list) == 0 or test in test_list or "all" in test_list: - work_dict["./Traces/"+test] = suite["Traces"][test] - - result_list = [] - for test_path in work_dict: - print ("MSG: Testing on trace files of:", test_path) - work_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) - result = callTraceTools(work_dir, work_dict[test_path]) - if result != 'PASS': - r += 1 - record = {"name": test_path, "result": result} - result_list.append(record) - - return r, result_list + global tracediff_script + global traceunion_script + global traceontograph_script + global tracetodot_script + + r = 0 + suite = {} + script_dir = os.path.dirname(os.path.realpath(__file__)) + llfi_tools_dir = os.path.join(script_dir, "../../tools") + tracediff_script = os.path.join(llfi_tools_dir, "tracediff") + traceunion_script = os.path.join(llfi_tools_dir, "traceunion") + traceontograph_script = os.path.join(llfi_tools_dir, "traceontograph") + tracetodot_script = os.path.join(llfi_tools_dir, "tracetodot") + + testsuite_dir = os.path.join(script_dir, os.pardir) + with open(os.path.join(testsuite_dir, "test_suite.yaml")) as f: + try: + suite = yaml.safe_load(f) + except Exception: + print("ERROR: Unable to load yaml file: test_suite.yaml", file=sys.stderr) + return -1 + + work_dict = {} + for test in suite["Traces"]: + if len(test_list) == 0 or test in test_list or "all" in test_list: + work_dict["./Traces/" + test] = suite["Traces"][test] + + result_list = [] + for test_path in work_dict: + print("MSG: Testing on trace files of:", test_path) + work_dir = os.path.abspath(os.path.join(testsuite_dir, test_path)) + result = callTraceTools(work_dir, work_dict[test_path]) + if result != "PASS": + r += 1 + record = {"name": test_path, "result": result} + result_list.append(record) + + return r, result_list + if __name__ == "__main__": - r, result_list = test_trace_tools(*sys.argv[1:]) - print ("=============== Result ===============") - for record in result_list: - print(record["name"], "\t\t", record["result"]) + r, result_list = test_trace_tools(*sys.argv[1:]) + print("=============== Result ===============") + for record in result_list: + print(record["name"], "\t\t", record["result"]) - sys.exit(r) + sys.exit(r) diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/input.yaml b/test_suite/Traces/BufferOverflowMemmove_Data/input.yaml deleted file mode 100644 index 7dcf633a..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/input.yaml +++ /dev/null @@ -1,27 +0,0 @@ -defaultTimeOut: 100 - -compileOption: - instSelMethod: - - customInstselector: - include: - - BufferOverflowMemmove(Data) - - regSelMethod: customregselector - customRegSelector: Automatic - - tracingPropagation: True # trace dynamic instruction values. - - tracingPropagationOption: - maxTrace: 250 # max number of instructions to trace during fault injection run - debugTrace: False - generateCDFG: True - -runOption: - - run: - numOfRuns: 5 - fi_type: AutoInjection - - - - - diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.config.compiletime.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi.config.compiletime.txt deleted file mode 100644 index 4ee7447e..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.config.compiletime.txt +++ /dev/null @@ -1,4 +0,0 @@ -failure_class=Data -failure_mode=BufferOverflow -targets=memcpy()/memmove() -injector=ChangeValueInjector diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.config.rumtime.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi.config.rumtime.txt deleted file mode 100644 index 8cfe76b8..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.config.rumtime.txt +++ /dev/null @@ -1,2 +0,0 @@ -fi_cycle=0 -fi_type=BufferOverflowMemmove(Data) diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.log.compilation.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi.log.compilation.txt deleted file mode 100644 index 09bba38b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.log.compilation.txt +++ /dev/null @@ -1,6 +0,0 @@ - - -Start of a pass - - -Start of a pass diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.graph.dot b/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.graph.dot deleted file mode 100644 index 00bfe0d5..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.graph.dot +++ /dev/null @@ -1,311 +0,0 @@ -digraph "LLFI Program Graph" { -llfiID_0 -> llfiID_90 [color="blue"]; -llfiID_0 -> llfiID_88 [color="blue"]; -llfiID_0 -> llfiID_27 [color="blue"]; -llfiID_0 -> llfiID_17 [color="blue"]; -llfiID_0 -> llfiID_7 [color="blue"]; -llfiID_1 -> llfiID_56 [color="blue"]; -llfiID_1 -> llfiID_37 [color="blue"]; -llfiID_1 -> llfiID_19 [color="blue"]; -llfiID_1 -> llfiID_9 [color="blue"]; -llfiID_1 -> llfiID_8 [color="blue"]; -llfiID_2 -> llfiID_84 [color="blue"]; -llfiID_2 -> llfiID_46 [color="blue"]; -llfiID_2 -> llfiID_43 [color="blue"]; -llfiID_2 -> llfiID_36 [color="blue"]; -llfiID_2 -> llfiID_13 [color="blue"]; -llfiID_2 -> llfiID_12 [color="blue"]; -llfiID_3 -> llfiID_86 [color="blue"]; -llfiID_3 -> llfiID_55 [color="blue"]; -llfiID_3 -> llfiID_48 [color="blue"]; -llfiID_3 -> llfiID_45 [color="blue"]; -llfiID_3 -> llfiID_23 [color="blue"]; -llfiID_3 -> llfiID_22 [color="blue"]; -llfiID_4 -> llfiID_82 [color="blue"]; -llfiID_4 -> llfiID_77 [color="blue"]; -llfiID_4 -> llfiID_65 [color="blue"]; -llfiID_4 -> llfiID_61 [color="blue"]; -llfiID_4 -> llfiID_58 [color="blue"]; -llfiID_4 -> llfiID_52 [color="blue"]; -llfiID_4 -> llfiID_51 [color="blue"]; -llfiID_4 -> llfiID_41 [color="blue"]; -llfiID_4 -> llfiID_39 [color="blue"]; -llfiID_4 -> llfiID_31 [color="blue"]; -llfiID_4 -> llfiID_30 [color="blue"]; -llfiID_6 -> llfiID_80 [color="blue"]; -llfiID_6 -> llfiID_74 [color="blue"]; -llfiID_6 -> llfiID_70 [color="blue"]; -llfiID_6 -> llfiID_68 [color="blue"]; -llfiID_9 -> llfiID_10 [color="blue"]; -llfiID_10 -> llfiID_11 [color="blue"]; -llfiID_11 -> llfiID_12 [color="blue"]; -llfiID_13 -> llfiID_14 [color="blue"]; -llfiID_14 -> llfiID_15 [color="blue"]; -llfiID_19 -> llfiID_20 [color="blue"]; -llfiID_20 -> llfiID_21 [color="blue"]; -llfiID_21 -> llfiID_22 [color="blue"]; -llfiID_23 -> llfiID_24 [color="blue"]; -llfiID_24 -> llfiID_25 [color="blue"]; -llfiID_29 -> llfiID_30 [color="blue"]; -llfiID_31 -> llfiID_32 [color="blue"]; -llfiID_32 -> llfiID_33 [color="blue"]; -llfiID_36 -> llfiID_40 [color="blue"]; -llfiID_37 -> llfiID_38 [color="blue"]; -llfiID_38 -> llfiID_40 [color="blue"]; -llfiID_39 -> llfiID_40 [color="blue"]; -llfiID_41 -> llfiID_42 [color="blue"]; -llfiID_43 -> llfiID_44 [color="blue"]; -llfiID_45 -> llfiID_47 [color="blue"]; -llfiID_46 -> llfiID_47 [color="blue"]; -llfiID_48 -> llfiID_49 [color="blue"]; -llfiID_50 -> llfiID_51 [color="blue"]; -llfiID_52 -> llfiID_53 [color="blue"]; -llfiID_53 -> llfiID_54 [color="blue"]; -llfiID_55 -> llfiID_59 [color="blue"]; -llfiID_56 -> llfiID_57 [color="blue"]; -llfiID_57 -> llfiID_59 [color="blue"]; -llfiID_58 -> llfiID_59 [color="blue"]; -llfiID_61 -> llfiID_62 [color="blue"]; -llfiID_65 -> llfiID_66 [color="blue"]; -llfiID_66 -> llfiID_67 [color="blue"]; -llfiID_67 -> llfiID_68 [color="blue"]; -llfiID_70 -> llfiID_71 [color="blue"]; -llfiID_71 -> llfiID_72 [color="blue"]; -llfiID_72 -> llfiID_73 [color="blue"]; -llfiID_74 -> llfiID_75 [color="blue"]; -llfiID_75 -> llfiID_76 [color="blue"]; -llfiID_77 -> llfiID_78 [color="blue"]; -llfiID_78 -> llfiID_79 [color="blue"]; -llfiID_79 -> llfiID_80 [color="blue"]; -llfiID_82 -> llfiID_83 [color="blue"]; -llfiID_84 -> llfiID_85 [color="blue"]; -llfiID_86 -> llfiID_87 [color="blue"]; -llfiID_90 -> llfiID_91 [color="blue"]; -subgraph cluster_main_ { -label = "main_"; -llfiID_0 [shape=record,label="0\nalloca\n"]; -llfiID_1 [shape=record,label="1\nalloca\n"]; -llfiID_2 [shape=record,label="2\nalloca\n"]; -llfiID_3 [shape=record,label="3\nalloca\n"]; -llfiID_4 [shape=record,label="4\nalloca\n"]; -llfiID_5 [shape=record,label="5\nalloca\n"]; -llfiID_6 [shape=record,label="6\nalloca\n"]; -llfiID_7 [shape=record,label="7\nstore\n"]; -llfiID_8 [shape=record,label="8\nstore\n"]; -llfiID_9 [shape=record,label="9\nload\n"]; -llfiID_10 [shape=record,label="10\nsext\n"]; -llfiID_11 [shape=record,label="11\ncall\n"]; -llfiID_12 [shape=record,label="12\nstore\n"]; -llfiID_13 [shape=record,label="13\nload\n"]; -llfiID_14 [shape=record,label="14\nicmp\n"]; -llfiID_15 [shape=record,label="15\nbr\n"]; -} -llfiID_0 -> llfiID_1; -llfiID_1 -> llfiID_2; -llfiID_2 -> llfiID_3; -llfiID_3 -> llfiID_4; -llfiID_4 -> llfiID_5; -llfiID_5 -> llfiID_6; -llfiID_6 -> llfiID_7; -llfiID_7 -> llfiID_8; -llfiID_8 -> llfiID_9; -llfiID_9 -> llfiID_10; -llfiID_10 -> llfiID_11; -llfiID_11 -> llfiID_12; -llfiID_12 -> llfiID_13; -llfiID_13 -> llfiID_14; -llfiID_14 -> llfiID_15; -llfiID_15 -> llfiID_16; -llfiID_15 -> llfiID_19; -subgraph cluster_main_ { -label = "main_"; -llfiID_16 [shape=record,label="16\ncall\n"]; -llfiID_17 [shape=record,label="17\nstore\n"]; -llfiID_18 [shape=record,label="18\nbr\n"]; -} -llfiID_16 -> llfiID_17; -llfiID_17 -> llfiID_18; -llfiID_18 -> llfiID_90; -subgraph cluster_main_ { -label = "main_"; -llfiID_19 [shape=record,label="19\nload\n"]; -llfiID_20 [shape=record,label="20\nsext\n"]; -llfiID_21 [shape=record,label="21\ncall\n"]; -llfiID_22 [shape=record,label="22\nstore\n"]; -llfiID_23 [shape=record,label="23\nload\n"]; -llfiID_24 [shape=record,label="24\nicmp\n"]; -llfiID_25 [shape=record,label="25\nbr\n"]; -} -llfiID_19 -> llfiID_20; -llfiID_20 -> llfiID_21; -llfiID_21 -> llfiID_22; -llfiID_22 -> llfiID_23; -llfiID_23 -> llfiID_24; -llfiID_24 -> llfiID_25; -llfiID_25 -> llfiID_26; -llfiID_25 -> llfiID_29; -subgraph cluster_main_ { -label = "main_"; -llfiID_26 [shape=record,label="26\ncall\n"]; -llfiID_27 [shape=record,label="27\nstore\n"]; -llfiID_28 [shape=record,label="28\nbr\n"]; -} -llfiID_26 -> llfiID_27; -llfiID_27 -> llfiID_28; -llfiID_28 -> llfiID_90; -subgraph cluster_main_ { -label = "main_"; -llfiID_29 [shape=record,label="29\ncall\n"]; -llfiID_30 [shape=record,label="30\nstore\n"]; -llfiID_31 [shape=record,label="31\nload\n"]; -llfiID_32 [shape=record,label="32\nicmp\n"]; -llfiID_33 [shape=record,label="33\nbr\n"]; -} -llfiID_29 -> llfiID_30; -llfiID_30 -> llfiID_31; -llfiID_31 -> llfiID_32; -llfiID_32 -> llfiID_33; -llfiID_33 -> llfiID_34; -llfiID_33 -> llfiID_36; -subgraph cluster_main_ { -label = "main_"; -llfiID_34 [shape=record,label="34\ncall\n"]; -llfiID_35 [shape=record,label="35\nbr\n"]; -} -llfiID_34 -> llfiID_35; -llfiID_35 -> llfiID_36; -subgraph cluster_main_ { -label = "main_"; -llfiID_36 [shape=record,label="36\nload\n"]; -llfiID_37 [shape=record,label="37\nload\n"]; -llfiID_38 [shape=record,label="38\nsext\n"]; -llfiID_39 [shape=record,label="39\nload\n"]; -llfiID_40 [shape=record,label="40\ncall\n"]; -llfiID_41 [shape=record,label="41\nload\n"]; -llfiID_42 [shape=record,label="42\ncall\n"]; -llfiID_43 [shape=record,label="43\nload\n"]; -llfiID_44 [shape=record,label="44\ncall\n"]; -llfiID_45 [shape=record,label="45\nload\n"]; -llfiID_46 [shape=record,label="46\nload\n"]; -llfiID_47 [shape=record,label="47\ncall\n"]; -llfiID_48 [shape=record,label="48\nload\n"]; -llfiID_49 [shape=record,label="49\ncall\n"]; -llfiID_50 [shape=record,label="50\ncall\n"]; -llfiID_51 [shape=record,label="51\nstore\n"]; -llfiID_52 [shape=record,label="52\nload\n"]; -llfiID_53 [shape=record,label="53\nicmp\n"]; -llfiID_54 [shape=record,label="54\nbr\n"]; -} -llfiID_36 -> llfiID_37; -llfiID_37 -> llfiID_38; -llfiID_38 -> llfiID_39; -llfiID_39 -> llfiID_40; -llfiID_40 -> llfiID_41; -llfiID_41 -> llfiID_42; -llfiID_42 -> llfiID_43; -llfiID_43 -> llfiID_44; -llfiID_44 -> llfiID_45; -llfiID_45 -> llfiID_46; -llfiID_46 -> llfiID_47; -llfiID_47 -> llfiID_48; -llfiID_48 -> llfiID_49; -llfiID_49 -> llfiID_50; -llfiID_50 -> llfiID_51; -llfiID_51 -> llfiID_52; -llfiID_52 -> llfiID_53; -llfiID_53 -> llfiID_54; -llfiID_54 -> llfiID_55; -llfiID_54 -> llfiID_61; -subgraph cluster_main_ { -label = "main_"; -llfiID_55 [shape=record,label="55\nload\n"]; -llfiID_56 [shape=record,label="56\nload\n"]; -llfiID_57 [shape=record,label="57\nsext\n"]; -llfiID_58 [shape=record,label="58\nload\n"]; -llfiID_59 [shape=record,label="59\ncall\n"]; -llfiID_60 [shape=record,label="60\nbr\n"]; -} -llfiID_55 -> llfiID_56; -llfiID_56 -> llfiID_57; -llfiID_57 -> llfiID_58; -llfiID_58 -> llfiID_59; -llfiID_59 -> llfiID_60; -llfiID_60 -> llfiID_61; -subgraph cluster_main_ { -label = "main_"; -llfiID_61 [shape=record,label="61\nload\n"]; -llfiID_62 [shape=record,label="62\ncall\n"]; -llfiID_63 [shape=record,label="63\ncall\n"]; -llfiID_64 [shape=record,label="64\ncall\n"]; -llfiID_65 [shape=record,label="65\nload\n"]; -llfiID_66 [shape=record,label="66\ncall\n"]; -llfiID_67 [shape=record,label="67\ntrunc\n"]; -llfiID_68 [shape=record,label="68\nstore\n"]; -llfiID_69 [shape=record,label="69\nbr\n"]; -} -llfiID_61 -> llfiID_62; -llfiID_62 -> llfiID_63; -llfiID_63 -> llfiID_64; -llfiID_64 -> llfiID_65; -llfiID_65 -> llfiID_66; -llfiID_66 -> llfiID_67; -llfiID_67 -> llfiID_68; -llfiID_68 -> llfiID_69; -llfiID_69 -> llfiID_70; -subgraph cluster_main_ { -label = "main_"; -llfiID_70 [shape=record,label="70\nload\n"]; -llfiID_71 [shape=record,label="71\nsext\n"]; -llfiID_72 [shape=record,label="72\nicmp\n"]; -llfiID_73 [shape=record,label="73\nbr\n"]; -} -llfiID_70 -> llfiID_71; -llfiID_71 -> llfiID_72; -llfiID_72 -> llfiID_73; -llfiID_73 -> llfiID_74; -llfiID_73 -> llfiID_82; -subgraph cluster_main_ { -label = "main_"; -llfiID_74 [shape=record,label="74\nload\n"]; -llfiID_75 [shape=record,label="75\nsext\n"]; -llfiID_76 [shape=record,label="76\ncall\n"]; -llfiID_77 [shape=record,label="77\nload\n"]; -llfiID_78 [shape=record,label="78\ncall\n"]; -llfiID_79 [shape=record,label="79\ntrunc\n"]; -llfiID_80 [shape=record,label="80\nstore\n"]; -llfiID_81 [shape=record,label="81\nbr\n"]; -} -llfiID_74 -> llfiID_75; -llfiID_75 -> llfiID_76; -llfiID_76 -> llfiID_77; -llfiID_77 -> llfiID_78; -llfiID_78 -> llfiID_79; -llfiID_79 -> llfiID_80; -llfiID_80 -> llfiID_81; -llfiID_81 -> llfiID_70; -subgraph cluster_main_ { -label = "main_"; -llfiID_82 [shape=record,label="82\nload\n"]; -llfiID_83 [shape=record,label="83\ncall\n"]; -llfiID_84 [shape=record,label="84\nload\n"]; -llfiID_85 [shape=record,label="85\ncall\n"]; -llfiID_86 [shape=record,label="86\nload\n"]; -llfiID_87 [shape=record,label="87\ncall\n"]; -llfiID_88 [shape=record,label="88\nstore\n"]; -llfiID_89 [shape=record,label="89\nbr\n"]; -} -llfiID_82 -> llfiID_83; -llfiID_83 -> llfiID_84; -llfiID_84 -> llfiID_85; -llfiID_85 -> llfiID_86; -llfiID_86 -> llfiID_87; -llfiID_87 -> llfiID_88; -llfiID_88 -> llfiID_89; -llfiID_89 -> llfiID_90; -subgraph cluster_main_ { -label = "main_"; -llfiID_90 [shape=record,label="90\nload\n"]; -llfiID_91 [shape=record,label="91\nret\n"]; -} -llfiID_90 -> llfiID_91; -{ rank = sink;Legend [shape=none, margin=0, label=<
Legend
Normal Control Flow solid arrow
Data Dependancy solid arrow
Control Flow Error dashed arrow
Fault Affected Instruction
Fault Injected Instruction red border
>];}} diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.prof.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.prof.txt deleted file mode 100644 index 519277e6..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.prof.txt +++ /dev/null @@ -1,3 +0,0 @@ -# do not edit -# cycle considered the execution cycle of each instruction type -total_cycle=1 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.totalindex.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.totalindex.txt deleted file mode 100644 index 20a0c1c2..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi.stat.totalindex.txt +++ /dev/null @@ -1 +0,0 @@ -totalindex=91 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/golden_std_output b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/golden_std_output deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/golden_std_output +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/llfi.stat.trace.prof.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/llfi.stat.trace.prof.txt deleted file mode 100644 index 0f2cf343..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/llfi.stat.trace.prof.txt +++ /dev/null @@ -1,676 +0,0 @@ -ID: 0 OPCode: alloca Value: 00007fff51caab74 -ID: 1 OPCode: alloca Value: 00007fff51caab70 -ID: 2 OPCode: alloca Value: 00007fff51caab68 -ID: 3 OPCode: alloca Value: 00007fff51caab60 -ID: 4 OPCode: alloca Value: 00007fff51caab58 -ID: 5 OPCode: alloca Value: 00007fff51caab54 -ID: 6 OPCode: alloca Value: 00007fff51caab53 -ID: 7 OPCode: store Value: 00000000 -ID: 8 OPCode: store Value: 00000000 -ID: 9 OPCode: load Value: 00000032 -ID: 10 OPCode: sext Value: 0000000000000032 -ID: 11 OPCode: call Value: 0000000000bc29c0 -ID: 12 OPCode: store Value: 00000000 -ID: 13 OPCode: load Value: 0000000000bc29c0 -ID: 14 OPCode: icmp Value: 00 -ID: 15 OPCode: br Value: 00000000 -ID: 19 OPCode: load Value: 00000032 -ID: 20 OPCode: sext Value: 0000000000000032 -ID: 21 OPCode: call Value: 0000000000bc2a00 -ID: 22 OPCode: store Value: 00000000 -ID: 23 OPCode: load Value: 0000000000bc2a00 -ID: 24 OPCode: icmp Value: 00 -ID: 25 OPCode: br Value: 00000000 -ID: 29 OPCode: call Value: 0000000000bc2a40 -ID: 30 OPCode: store Value: 00000000 -ID: 31 OPCode: load Value: 0000000000bc2a40 -ID: 32 OPCode: icmp Value: 00 -ID: 33 OPCode: br Value: 00000000 -ID: 36 OPCode: load Value: 0000000000bc29c0 -ID: 37 OPCode: load Value: 00000032 -ID: 38 OPCode: sext Value: 0000000000000032 -ID: 39 OPCode: load Value: 0000000000bc2a40 -ID: 40 OPCode: call Value: 0000000000000032 -ID: 41 OPCode: load Value: 0000000000bc2a40 -ID: 42 OPCode: call Value: 00000000 -ID: 43 OPCode: load Value: 0000000000bc29c0 -ID: 44 OPCode: call Value: 00000052 -ID: 45 OPCode: load Value: 0000000000bc2a00 -ID: 46 OPCode: load Value: 0000000000bc29c0 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000000bc2a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000000bc2a40 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000000bc2a40 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000000bc2a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000000bc2a40 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000000bc2a40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000000bc2a40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000000bc2a40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006d -ID: 79 OPCode: trunc Value: 6d -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6d -ID: 71 OPCode: sext Value: 0000006d -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6d -ID: 75 OPCode: sext Value: 0000006d -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000066 -ID: 79 OPCode: trunc Value: 66 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 66 -ID: 71 OPCode: sext Value: 00000066 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 66 -ID: 75 OPCode: sext Value: 00000066 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000077 -ID: 79 OPCode: trunc Value: 77 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 77 -ID: 71 OPCode: sext Value: 00000077 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 77 -ID: 75 OPCode: sext Value: 00000077 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006e -ID: 79 OPCode: trunc Value: 6e -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6e -ID: 71 OPCode: sext Value: 0000006e -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6e -ID: 75 OPCode: sext Value: 0000006e -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000066 -ID: 79 OPCode: trunc Value: 66 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 66 -ID: 71 OPCode: sext Value: 00000066 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 66 -ID: 75 OPCode: sext Value: 00000066 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006d -ID: 79 OPCode: trunc Value: 6d -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6d -ID: 71 OPCode: sext Value: 0000006d -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6d -ID: 75 OPCode: sext Value: 0000006d -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 0000006e -ID: 79 OPCode: trunc Value: 6e -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6e -ID: 71 OPCode: sext Value: 0000006e -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6e -ID: 75 OPCode: sext Value: 0000006e -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000041 -ID: 79 OPCode: trunc Value: 41 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 41 -ID: 71 OPCode: sext Value: 00000041 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 41 -ID: 75 OPCode: sext Value: 00000041 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000050 -ID: 79 OPCode: trunc Value: 50 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 50 -ID: 71 OPCode: sext Value: 00000050 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 50 -ID: 75 OPCode: sext Value: 00000050 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000049 -ID: 79 OPCode: trunc Value: 49 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 49 -ID: 71 OPCode: sext Value: 00000049 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 49 -ID: 75 OPCode: sext Value: 00000049 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000bc2a40 -ID: 78 OPCode: call Value: ffffffff -ID: 79 OPCode: trunc Value: ff -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: ff -ID: 71 OPCode: sext Value: ffffffff -ID: 72 OPCode: icmp Value: 00 -ID: 73 OPCode: br Value: 00000000 -ID: 82 OPCode: load Value: 0000000000bc2a40 -ID: 83 OPCode: call Value: 00000000 -ID: 84 OPCode: load Value: 0000000000bc29c0 -ID: 85 OPCode: call Value: 00000000 -ID: 86 OPCode: load Value: 0000000000bc2a00 -ID: 87 OPCode: call Value: 00000000 -ID: 88 OPCode: store Value: 00000000 -ID: 89 OPCode: br Value: 00000000 -ID: 90 OPCode: load Value: 00000000 -ID: 91 OPCode: ret Value: 00000000 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/output.prof.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/output.prof.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/baseline/output.prof.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-0.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-0.txt deleted file mode 100644 index bc3f39a1..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-0.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflowMemmove(Data), fi_index=47, fi_cycle=0, fi_reg_index=0, fi_bit=48 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-1.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-1.txt deleted file mode 100644 index 070c53d7..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-1.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflowMemmove(Data), fi_index=47, fi_cycle=0, fi_reg_index=0, fi_bit=62 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-2.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-2.txt deleted file mode 100644 index 1687d637..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-2.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflowMemmove(Data), fi_index=47, fi_cycle=0, fi_reg_index=0, fi_bit=49 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-3.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-3.txt deleted file mode 100644 index 15255f1f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-3.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflowMemmove(Data), fi_index=47, fi_cycle=0, fi_reg_index=0, fi_bit=39 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-4.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-4.txt deleted file mode 100644 index 76c123e9..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-4.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflowMemmove(Data), fi_index=47, fi_cycle=0, fi_reg_index=0, fi_bit=42 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-0.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-0.txt deleted file mode 100644 index 49f6bf0f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-0.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 40 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000002388a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000002388cc0 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000002388cc0 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000002388a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000002388cc0 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000002388cc0 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000002388cc0 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000002388cc0 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000002388cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-1.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-1.txt deleted file mode 100644 index c6a7d259..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-1.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 40 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000000c26a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000000c26cc0 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000000c26cc0 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000000c26a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000000c26cc0 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000000c26cc0 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000000c26cc0 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000000c26cc0 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000000c26cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-2.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-2.txt deleted file mode 100644 index 6ba1f259..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-2.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 40 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 00000000023bfa00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 00000000023bfcc0 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 00000000023bfcc0 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 00000000023bfa00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 00000000023bfcc0 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 00000000023bfcc0 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 00000000023bfcc0 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 00000000023bfcc0 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000023bfcc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-3.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-3.txt deleted file mode 100644 index 7bfd92c8..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-3.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 40 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000001e59a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000001e59cc0 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000001e59cc0 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000001e59a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000001e59cc0 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000001e59cc0 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000001e59cc0 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000001e59cc0 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001e59cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-4.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-4.txt deleted file mode 100644 index 8c9aca40..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/llfi_stat_output/llfi.stat.trace.0-4.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 40 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000001286a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000001286cc0 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000001286cc0 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000001286a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000001286cc0 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000001286cc0 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000001286cc0 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000001286cc0 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001286cc0 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-0.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-0.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-0.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-1.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-1.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-1.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-2.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-2.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-2.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-3.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-3.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-3.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-4.txt b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-4.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/prog_output/output.0-4.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-0 b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-0 deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-0 +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-1 b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-1 deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-1 +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-2 b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-2 deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-2 +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-3 b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-3 deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-3 +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-4 b/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-4 deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/llfi/std_output/std_outputfile-run-0-4 +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflowMemmove_Data/sample.txt b/test_suite/Traces/BufferOverflowMemmove_Data/sample.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflowMemmove_Data/sample.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflow_API/input.yaml b/test_suite/Traces/BufferOverflow_API/input.yaml deleted file mode 100644 index d4da9e3b..00000000 --- a/test_suite/Traces/BufferOverflow_API/input.yaml +++ /dev/null @@ -1,27 +0,0 @@ - -kernelOption: - - forceRun - -compileOption: - instSelMethod: - - customInstselector: - include: - - BufferOverflow(API) - - regSelMethod: customregselector - customRegSelector: SoftwareFault - - includeInjectionTrace: - - backward - - tracingPropagation: True # trace dynamic instruction values. - - tracingPropagationOption: - maxTrace: 250 # max number of instructions to trace during fault injection run - debugTrace: False - generateCDFG: True - -runOption: - - run: - numOfRuns: 5 - fi_type: SoftwareFault diff --git a/test_suite/Traces/BufferOverflow_API/llfi.config.compiletime.txt b/test_suite/Traces/BufferOverflow_API/llfi.config.compiletime.txt deleted file mode 100644 index a49fbe65..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi.config.compiletime.txt +++ /dev/null @@ -1,4 +0,0 @@ -failure_class=API -failure_mode=BufferOverflow -targets=fread()/fwrite() -injector=ChangeValueInjector diff --git a/test_suite/Traces/BufferOverflow_API/llfi.config.rumtime.txt b/test_suite/Traces/BufferOverflow_API/llfi.config.rumtime.txt deleted file mode 100644 index 1897c07d..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi.config.rumtime.txt +++ /dev/null @@ -1,2 +0,0 @@ -fi_cycle=1 -fi_type=BufferOverflow(API) diff --git a/test_suite/Traces/BufferOverflow_API/llfi.log.compilation.txt b/test_suite/Traces/BufferOverflow_API/llfi.log.compilation.txt deleted file mode 100644 index b28c88dd..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi.log.compilation.txt +++ /dev/null @@ -1,30 +0,0 @@ - - -Start of a pass -The selected instruction %size = alloca i32, align 4, !llfi_index !2does not have any valid registers for fault injection -The selected instruction %src = alloca i8*, align 8, !llfi_index !3does not have any valid registers for fault injection -The selected instruction %dest = alloca i8*, align 8, !llfi_index !4does not have any valid registers for fault injection -The selected instruction %fp = alloca %struct._IO_FILE*, align 8, !llfi_index !5does not have any valid registers for fault injection -The selected instruction %24 = load i8** %src, align 8, !llfi_index !37does not have any valid registers for fault injection -The selected instruction %25 = load i32* %size, align 4, !llfi_index !38does not have any valid registers for fault injection -The selected instruction %26 = sext i32 %25 to i64, !llfi_index !39does not have any valid registers for fault injection -The selected instruction %27 = load %struct._IO_FILE** %fp, align 8, !llfi_index !40does not have any valid registers for fault injection -The selected instruction %41 = load i8** %dest, align 8, !llfi_index !56does not have any valid registers for fault injection -The selected instruction %42 = load i32* %size, align 4, !llfi_index !57does not have any valid registers for fault injection -The selected instruction %43 = sext i32 %42 to i64, !llfi_index !58does not have any valid registers for fault injection -The selected instruction %44 = load %struct._IO_FILE** %fp, align 8, !llfi_index !59does not have any valid registers for fault injection - - -Start of a pass -The selected instruction %size = alloca i32, align 4, !llfi_index !2does not have any valid registers for fault injection -The selected instruction %src = alloca i8*, align 8, !llfi_index !3does not have any valid registers for fault injection -The selected instruction %dest = alloca i8*, align 8, !llfi_index !4does not have any valid registers for fault injection -The selected instruction %fp = alloca %struct._IO_FILE*, align 8, !llfi_index !5does not have any valid registers for fault injection -The selected instruction %24 = load i8** %src, align 8, !llfi_index !37does not have any valid registers for fault injection -The selected instruction %25 = load i32* %size, align 4, !llfi_index !38does not have any valid registers for fault injection -The selected instruction %26 = sext i32 %25 to i64, !llfi_index !39does not have any valid registers for fault injection -The selected instruction %27 = load %struct._IO_FILE** %fp, align 8, !llfi_index !40does not have any valid registers for fault injection -The selected instruction %41 = load i8** %dest, align 8, !llfi_index !56does not have any valid registers for fault injection -The selected instruction %42 = load i32* %size, align 4, !llfi_index !57does not have any valid registers for fault injection -The selected instruction %43 = sext i32 %42 to i64, !llfi_index !58does not have any valid registers for fault injection -The selected instruction %44 = load %struct._IO_FILE** %fp, align 8, !llfi_index !59does not have any valid registers for fault injection diff --git a/test_suite/Traces/BufferOverflow_API/llfi.stat.graph.dot b/test_suite/Traces/BufferOverflow_API/llfi.stat.graph.dot deleted file mode 100644 index 00bfe0d5..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi.stat.graph.dot +++ /dev/null @@ -1,311 +0,0 @@ -digraph "LLFI Program Graph" { -llfiID_0 -> llfiID_90 [color="blue"]; -llfiID_0 -> llfiID_88 [color="blue"]; -llfiID_0 -> llfiID_27 [color="blue"]; -llfiID_0 -> llfiID_17 [color="blue"]; -llfiID_0 -> llfiID_7 [color="blue"]; -llfiID_1 -> llfiID_56 [color="blue"]; -llfiID_1 -> llfiID_37 [color="blue"]; -llfiID_1 -> llfiID_19 [color="blue"]; -llfiID_1 -> llfiID_9 [color="blue"]; -llfiID_1 -> llfiID_8 [color="blue"]; -llfiID_2 -> llfiID_84 [color="blue"]; -llfiID_2 -> llfiID_46 [color="blue"]; -llfiID_2 -> llfiID_43 [color="blue"]; -llfiID_2 -> llfiID_36 [color="blue"]; -llfiID_2 -> llfiID_13 [color="blue"]; -llfiID_2 -> llfiID_12 [color="blue"]; -llfiID_3 -> llfiID_86 [color="blue"]; -llfiID_3 -> llfiID_55 [color="blue"]; -llfiID_3 -> llfiID_48 [color="blue"]; -llfiID_3 -> llfiID_45 [color="blue"]; -llfiID_3 -> llfiID_23 [color="blue"]; -llfiID_3 -> llfiID_22 [color="blue"]; -llfiID_4 -> llfiID_82 [color="blue"]; -llfiID_4 -> llfiID_77 [color="blue"]; -llfiID_4 -> llfiID_65 [color="blue"]; -llfiID_4 -> llfiID_61 [color="blue"]; -llfiID_4 -> llfiID_58 [color="blue"]; -llfiID_4 -> llfiID_52 [color="blue"]; -llfiID_4 -> llfiID_51 [color="blue"]; -llfiID_4 -> llfiID_41 [color="blue"]; -llfiID_4 -> llfiID_39 [color="blue"]; -llfiID_4 -> llfiID_31 [color="blue"]; -llfiID_4 -> llfiID_30 [color="blue"]; -llfiID_6 -> llfiID_80 [color="blue"]; -llfiID_6 -> llfiID_74 [color="blue"]; -llfiID_6 -> llfiID_70 [color="blue"]; -llfiID_6 -> llfiID_68 [color="blue"]; -llfiID_9 -> llfiID_10 [color="blue"]; -llfiID_10 -> llfiID_11 [color="blue"]; -llfiID_11 -> llfiID_12 [color="blue"]; -llfiID_13 -> llfiID_14 [color="blue"]; -llfiID_14 -> llfiID_15 [color="blue"]; -llfiID_19 -> llfiID_20 [color="blue"]; -llfiID_20 -> llfiID_21 [color="blue"]; -llfiID_21 -> llfiID_22 [color="blue"]; -llfiID_23 -> llfiID_24 [color="blue"]; -llfiID_24 -> llfiID_25 [color="blue"]; -llfiID_29 -> llfiID_30 [color="blue"]; -llfiID_31 -> llfiID_32 [color="blue"]; -llfiID_32 -> llfiID_33 [color="blue"]; -llfiID_36 -> llfiID_40 [color="blue"]; -llfiID_37 -> llfiID_38 [color="blue"]; -llfiID_38 -> llfiID_40 [color="blue"]; -llfiID_39 -> llfiID_40 [color="blue"]; -llfiID_41 -> llfiID_42 [color="blue"]; -llfiID_43 -> llfiID_44 [color="blue"]; -llfiID_45 -> llfiID_47 [color="blue"]; -llfiID_46 -> llfiID_47 [color="blue"]; -llfiID_48 -> llfiID_49 [color="blue"]; -llfiID_50 -> llfiID_51 [color="blue"]; -llfiID_52 -> llfiID_53 [color="blue"]; -llfiID_53 -> llfiID_54 [color="blue"]; -llfiID_55 -> llfiID_59 [color="blue"]; -llfiID_56 -> llfiID_57 [color="blue"]; -llfiID_57 -> llfiID_59 [color="blue"]; -llfiID_58 -> llfiID_59 [color="blue"]; -llfiID_61 -> llfiID_62 [color="blue"]; -llfiID_65 -> llfiID_66 [color="blue"]; -llfiID_66 -> llfiID_67 [color="blue"]; -llfiID_67 -> llfiID_68 [color="blue"]; -llfiID_70 -> llfiID_71 [color="blue"]; -llfiID_71 -> llfiID_72 [color="blue"]; -llfiID_72 -> llfiID_73 [color="blue"]; -llfiID_74 -> llfiID_75 [color="blue"]; -llfiID_75 -> llfiID_76 [color="blue"]; -llfiID_77 -> llfiID_78 [color="blue"]; -llfiID_78 -> llfiID_79 [color="blue"]; -llfiID_79 -> llfiID_80 [color="blue"]; -llfiID_82 -> llfiID_83 [color="blue"]; -llfiID_84 -> llfiID_85 [color="blue"]; -llfiID_86 -> llfiID_87 [color="blue"]; -llfiID_90 -> llfiID_91 [color="blue"]; -subgraph cluster_main_ { -label = "main_"; -llfiID_0 [shape=record,label="0\nalloca\n"]; -llfiID_1 [shape=record,label="1\nalloca\n"]; -llfiID_2 [shape=record,label="2\nalloca\n"]; -llfiID_3 [shape=record,label="3\nalloca\n"]; -llfiID_4 [shape=record,label="4\nalloca\n"]; -llfiID_5 [shape=record,label="5\nalloca\n"]; -llfiID_6 [shape=record,label="6\nalloca\n"]; -llfiID_7 [shape=record,label="7\nstore\n"]; -llfiID_8 [shape=record,label="8\nstore\n"]; -llfiID_9 [shape=record,label="9\nload\n"]; -llfiID_10 [shape=record,label="10\nsext\n"]; -llfiID_11 [shape=record,label="11\ncall\n"]; -llfiID_12 [shape=record,label="12\nstore\n"]; -llfiID_13 [shape=record,label="13\nload\n"]; -llfiID_14 [shape=record,label="14\nicmp\n"]; -llfiID_15 [shape=record,label="15\nbr\n"]; -} -llfiID_0 -> llfiID_1; -llfiID_1 -> llfiID_2; -llfiID_2 -> llfiID_3; -llfiID_3 -> llfiID_4; -llfiID_4 -> llfiID_5; -llfiID_5 -> llfiID_6; -llfiID_6 -> llfiID_7; -llfiID_7 -> llfiID_8; -llfiID_8 -> llfiID_9; -llfiID_9 -> llfiID_10; -llfiID_10 -> llfiID_11; -llfiID_11 -> llfiID_12; -llfiID_12 -> llfiID_13; -llfiID_13 -> llfiID_14; -llfiID_14 -> llfiID_15; -llfiID_15 -> llfiID_16; -llfiID_15 -> llfiID_19; -subgraph cluster_main_ { -label = "main_"; -llfiID_16 [shape=record,label="16\ncall\n"]; -llfiID_17 [shape=record,label="17\nstore\n"]; -llfiID_18 [shape=record,label="18\nbr\n"]; -} -llfiID_16 -> llfiID_17; -llfiID_17 -> llfiID_18; -llfiID_18 -> llfiID_90; -subgraph cluster_main_ { -label = "main_"; -llfiID_19 [shape=record,label="19\nload\n"]; -llfiID_20 [shape=record,label="20\nsext\n"]; -llfiID_21 [shape=record,label="21\ncall\n"]; -llfiID_22 [shape=record,label="22\nstore\n"]; -llfiID_23 [shape=record,label="23\nload\n"]; -llfiID_24 [shape=record,label="24\nicmp\n"]; -llfiID_25 [shape=record,label="25\nbr\n"]; -} -llfiID_19 -> llfiID_20; -llfiID_20 -> llfiID_21; -llfiID_21 -> llfiID_22; -llfiID_22 -> llfiID_23; -llfiID_23 -> llfiID_24; -llfiID_24 -> llfiID_25; -llfiID_25 -> llfiID_26; -llfiID_25 -> llfiID_29; -subgraph cluster_main_ { -label = "main_"; -llfiID_26 [shape=record,label="26\ncall\n"]; -llfiID_27 [shape=record,label="27\nstore\n"]; -llfiID_28 [shape=record,label="28\nbr\n"]; -} -llfiID_26 -> llfiID_27; -llfiID_27 -> llfiID_28; -llfiID_28 -> llfiID_90; -subgraph cluster_main_ { -label = "main_"; -llfiID_29 [shape=record,label="29\ncall\n"]; -llfiID_30 [shape=record,label="30\nstore\n"]; -llfiID_31 [shape=record,label="31\nload\n"]; -llfiID_32 [shape=record,label="32\nicmp\n"]; -llfiID_33 [shape=record,label="33\nbr\n"]; -} -llfiID_29 -> llfiID_30; -llfiID_30 -> llfiID_31; -llfiID_31 -> llfiID_32; -llfiID_32 -> llfiID_33; -llfiID_33 -> llfiID_34; -llfiID_33 -> llfiID_36; -subgraph cluster_main_ { -label = "main_"; -llfiID_34 [shape=record,label="34\ncall\n"]; -llfiID_35 [shape=record,label="35\nbr\n"]; -} -llfiID_34 -> llfiID_35; -llfiID_35 -> llfiID_36; -subgraph cluster_main_ { -label = "main_"; -llfiID_36 [shape=record,label="36\nload\n"]; -llfiID_37 [shape=record,label="37\nload\n"]; -llfiID_38 [shape=record,label="38\nsext\n"]; -llfiID_39 [shape=record,label="39\nload\n"]; -llfiID_40 [shape=record,label="40\ncall\n"]; -llfiID_41 [shape=record,label="41\nload\n"]; -llfiID_42 [shape=record,label="42\ncall\n"]; -llfiID_43 [shape=record,label="43\nload\n"]; -llfiID_44 [shape=record,label="44\ncall\n"]; -llfiID_45 [shape=record,label="45\nload\n"]; -llfiID_46 [shape=record,label="46\nload\n"]; -llfiID_47 [shape=record,label="47\ncall\n"]; -llfiID_48 [shape=record,label="48\nload\n"]; -llfiID_49 [shape=record,label="49\ncall\n"]; -llfiID_50 [shape=record,label="50\ncall\n"]; -llfiID_51 [shape=record,label="51\nstore\n"]; -llfiID_52 [shape=record,label="52\nload\n"]; -llfiID_53 [shape=record,label="53\nicmp\n"]; -llfiID_54 [shape=record,label="54\nbr\n"]; -} -llfiID_36 -> llfiID_37; -llfiID_37 -> llfiID_38; -llfiID_38 -> llfiID_39; -llfiID_39 -> llfiID_40; -llfiID_40 -> llfiID_41; -llfiID_41 -> llfiID_42; -llfiID_42 -> llfiID_43; -llfiID_43 -> llfiID_44; -llfiID_44 -> llfiID_45; -llfiID_45 -> llfiID_46; -llfiID_46 -> llfiID_47; -llfiID_47 -> llfiID_48; -llfiID_48 -> llfiID_49; -llfiID_49 -> llfiID_50; -llfiID_50 -> llfiID_51; -llfiID_51 -> llfiID_52; -llfiID_52 -> llfiID_53; -llfiID_53 -> llfiID_54; -llfiID_54 -> llfiID_55; -llfiID_54 -> llfiID_61; -subgraph cluster_main_ { -label = "main_"; -llfiID_55 [shape=record,label="55\nload\n"]; -llfiID_56 [shape=record,label="56\nload\n"]; -llfiID_57 [shape=record,label="57\nsext\n"]; -llfiID_58 [shape=record,label="58\nload\n"]; -llfiID_59 [shape=record,label="59\ncall\n"]; -llfiID_60 [shape=record,label="60\nbr\n"]; -} -llfiID_55 -> llfiID_56; -llfiID_56 -> llfiID_57; -llfiID_57 -> llfiID_58; -llfiID_58 -> llfiID_59; -llfiID_59 -> llfiID_60; -llfiID_60 -> llfiID_61; -subgraph cluster_main_ { -label = "main_"; -llfiID_61 [shape=record,label="61\nload\n"]; -llfiID_62 [shape=record,label="62\ncall\n"]; -llfiID_63 [shape=record,label="63\ncall\n"]; -llfiID_64 [shape=record,label="64\ncall\n"]; -llfiID_65 [shape=record,label="65\nload\n"]; -llfiID_66 [shape=record,label="66\ncall\n"]; -llfiID_67 [shape=record,label="67\ntrunc\n"]; -llfiID_68 [shape=record,label="68\nstore\n"]; -llfiID_69 [shape=record,label="69\nbr\n"]; -} -llfiID_61 -> llfiID_62; -llfiID_62 -> llfiID_63; -llfiID_63 -> llfiID_64; -llfiID_64 -> llfiID_65; -llfiID_65 -> llfiID_66; -llfiID_66 -> llfiID_67; -llfiID_67 -> llfiID_68; -llfiID_68 -> llfiID_69; -llfiID_69 -> llfiID_70; -subgraph cluster_main_ { -label = "main_"; -llfiID_70 [shape=record,label="70\nload\n"]; -llfiID_71 [shape=record,label="71\nsext\n"]; -llfiID_72 [shape=record,label="72\nicmp\n"]; -llfiID_73 [shape=record,label="73\nbr\n"]; -} -llfiID_70 -> llfiID_71; -llfiID_71 -> llfiID_72; -llfiID_72 -> llfiID_73; -llfiID_73 -> llfiID_74; -llfiID_73 -> llfiID_82; -subgraph cluster_main_ { -label = "main_"; -llfiID_74 [shape=record,label="74\nload\n"]; -llfiID_75 [shape=record,label="75\nsext\n"]; -llfiID_76 [shape=record,label="76\ncall\n"]; -llfiID_77 [shape=record,label="77\nload\n"]; -llfiID_78 [shape=record,label="78\ncall\n"]; -llfiID_79 [shape=record,label="79\ntrunc\n"]; -llfiID_80 [shape=record,label="80\nstore\n"]; -llfiID_81 [shape=record,label="81\nbr\n"]; -} -llfiID_74 -> llfiID_75; -llfiID_75 -> llfiID_76; -llfiID_76 -> llfiID_77; -llfiID_77 -> llfiID_78; -llfiID_78 -> llfiID_79; -llfiID_79 -> llfiID_80; -llfiID_80 -> llfiID_81; -llfiID_81 -> llfiID_70; -subgraph cluster_main_ { -label = "main_"; -llfiID_82 [shape=record,label="82\nload\n"]; -llfiID_83 [shape=record,label="83\ncall\n"]; -llfiID_84 [shape=record,label="84\nload\n"]; -llfiID_85 [shape=record,label="85\ncall\n"]; -llfiID_86 [shape=record,label="86\nload\n"]; -llfiID_87 [shape=record,label="87\ncall\n"]; -llfiID_88 [shape=record,label="88\nstore\n"]; -llfiID_89 [shape=record,label="89\nbr\n"]; -} -llfiID_82 -> llfiID_83; -llfiID_83 -> llfiID_84; -llfiID_84 -> llfiID_85; -llfiID_85 -> llfiID_86; -llfiID_86 -> llfiID_87; -llfiID_87 -> llfiID_88; -llfiID_88 -> llfiID_89; -llfiID_89 -> llfiID_90; -subgraph cluster_main_ { -label = "main_"; -llfiID_90 [shape=record,label="90\nload\n"]; -llfiID_91 [shape=record,label="91\nret\n"]; -} -llfiID_90 -> llfiID_91; -{ rank = sink;Legend [shape=none, margin=0, label=<
Legend
Normal Control Flow solid arrow
Data Dependancy solid arrow
Control Flow Error dashed arrow
Fault Affected Instruction
Fault Injected Instruction red border
>];}} diff --git a/test_suite/Traces/BufferOverflow_API/llfi.stat.prof.txt b/test_suite/Traces/BufferOverflow_API/llfi.stat.prof.txt deleted file mode 100644 index be2f18f1..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi.stat.prof.txt +++ /dev/null @@ -1,3 +0,0 @@ -# do not edit -# cycle considered the execution cycle of each instruction type -total_cycle=2 diff --git a/test_suite/Traces/BufferOverflow_API/llfi.stat.totalindex.txt b/test_suite/Traces/BufferOverflow_API/llfi.stat.totalindex.txt deleted file mode 100644 index 20a0c1c2..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi.stat.totalindex.txt +++ /dev/null @@ -1 +0,0 @@ -totalindex=91 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/baseline/golden_std_output b/test_suite/Traces/BufferOverflow_API/llfi/baseline/golden_std_output deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/baseline/golden_std_output +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflow_API/llfi/baseline/llfi.stat.trace.prof.txt b/test_suite/Traces/BufferOverflow_API/llfi/baseline/llfi.stat.trace.prof.txt deleted file mode 100644 index f1d6712f..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/baseline/llfi.stat.trace.prof.txt +++ /dev/null @@ -1,676 +0,0 @@ -ID: 0 OPCode: alloca Value: 00007fffb1d16604 -ID: 1 OPCode: alloca Value: 00007fffb1d16600 -ID: 2 OPCode: alloca Value: 00007fffb1d165f8 -ID: 3 OPCode: alloca Value: 00007fffb1d165f0 -ID: 4 OPCode: alloca Value: 00007fffb1d165e8 -ID: 5 OPCode: alloca Value: 00007fffb1d165e4 -ID: 6 OPCode: alloca Value: 00007fffb1d165e3 -ID: 7 OPCode: store Value: 00000000 -ID: 8 OPCode: store Value: 00000000 -ID: 9 OPCode: load Value: 00000032 -ID: 10 OPCode: sext Value: 0000000000000032 -ID: 11 OPCode: call Value: 00000000010419c0 -ID: 12 OPCode: store Value: 00000000 -ID: 13 OPCode: load Value: 00000000010419c0 -ID: 14 OPCode: icmp Value: 00 -ID: 15 OPCode: br Value: 00000000 -ID: 19 OPCode: load Value: 00000032 -ID: 20 OPCode: sext Value: 0000000000000032 -ID: 21 OPCode: call Value: 0000000001041a00 -ID: 22 OPCode: store Value: 00000000 -ID: 23 OPCode: load Value: 0000000001041a00 -ID: 24 OPCode: icmp Value: 00 -ID: 25 OPCode: br Value: 00000000 -ID: 29 OPCode: call Value: 0000000001041a40 -ID: 30 OPCode: store Value: 00000000 -ID: 31 OPCode: load Value: 0000000001041a40 -ID: 32 OPCode: icmp Value: 00 -ID: 33 OPCode: br Value: 00000000 -ID: 36 OPCode: load Value: 00000000010419c0 -ID: 37 OPCode: load Value: 00000032 -ID: 38 OPCode: sext Value: 0000000000000032 -ID: 39 OPCode: load Value: 0000000001041a40 -ID: 40 OPCode: call Value: 0000000000000032 -ID: 41 OPCode: load Value: 0000000001041a40 -ID: 42 OPCode: call Value: 00000000 -ID: 43 OPCode: load Value: 00000000010419c0 -ID: 44 OPCode: call Value: 00000052 -ID: 45 OPCode: load Value: 0000000001041a00 -ID: 46 OPCode: load Value: 00000000010419c0 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000001041a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000001041a40 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000001041a40 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000001041a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000001041a40 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000001041a40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000001041a40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000001041a40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006d -ID: 79 OPCode: trunc Value: 6d -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6d -ID: 71 OPCode: sext Value: 0000006d -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6d -ID: 75 OPCode: sext Value: 0000006d -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000066 -ID: 79 OPCode: trunc Value: 66 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 66 -ID: 71 OPCode: sext Value: 00000066 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 66 -ID: 75 OPCode: sext Value: 00000066 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000077 -ID: 79 OPCode: trunc Value: 77 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 77 -ID: 71 OPCode: sext Value: 00000077 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 77 -ID: 75 OPCode: sext Value: 00000077 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006e -ID: 79 OPCode: trunc Value: 6e -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6e -ID: 71 OPCode: sext Value: 0000006e -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6e -ID: 75 OPCode: sext Value: 0000006e -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000066 -ID: 79 OPCode: trunc Value: 66 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 66 -ID: 71 OPCode: sext Value: 00000066 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 66 -ID: 75 OPCode: sext Value: 00000066 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006d -ID: 79 OPCode: trunc Value: 6d -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6d -ID: 71 OPCode: sext Value: 0000006d -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6d -ID: 75 OPCode: sext Value: 0000006d -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 0000006e -ID: 79 OPCode: trunc Value: 6e -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6e -ID: 71 OPCode: sext Value: 0000006e -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6e -ID: 75 OPCode: sext Value: 0000006e -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000041 -ID: 79 OPCode: trunc Value: 41 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 41 -ID: 71 OPCode: sext Value: 00000041 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 41 -ID: 75 OPCode: sext Value: 00000041 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000050 -ID: 79 OPCode: trunc Value: 50 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 50 -ID: 71 OPCode: sext Value: 00000050 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 50 -ID: 75 OPCode: sext Value: 00000050 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000049 -ID: 79 OPCode: trunc Value: 49 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 49 -ID: 71 OPCode: sext Value: 00000049 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 49 -ID: 75 OPCode: sext Value: 00000049 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: 00000021 -ID: 79 OPCode: trunc Value: 21 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 21 -ID: 71 OPCode: sext Value: 00000021 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 21 -ID: 75 OPCode: sext Value: 00000021 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001041a40 -ID: 78 OPCode: call Value: ffffffff -ID: 79 OPCode: trunc Value: ff -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: ff -ID: 71 OPCode: sext Value: ffffffff -ID: 72 OPCode: icmp Value: 00 -ID: 73 OPCode: br Value: 00000000 -ID: 82 OPCode: load Value: 0000000001041a40 -ID: 83 OPCode: call Value: 00000000 -ID: 84 OPCode: load Value: 00000000010419c0 -ID: 85 OPCode: call Value: 00000000 -ID: 86 OPCode: load Value: 0000000001041a00 -ID: 87 OPCode: call Value: 00000000 -ID: 88 OPCode: store Value: 00000000 -ID: 89 OPCode: br Value: 00000000 -ID: 90 OPCode: load Value: 00000000 -ID: 91 OPCode: ret Value: 00000000 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/baseline/output.prof.txt b/test_suite/Traces/BufferOverflow_API/llfi/baseline/output.prof.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/baseline/output.prof.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-0.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-0.txt deleted file mode 100644 index 38829641..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-0.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflow(API), fi_index=59, fi_cycle=1, fi_reg_index=0, fi_bit=63 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-1.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-1.txt deleted file mode 100644 index b303dd62..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-1.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflow(API), fi_index=40, fi_cycle=0, fi_reg_index=0, fi_bit=12 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-2.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-2.txt deleted file mode 100644 index 00e4976d..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-2.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflow(API), fi_index=59, fi_cycle=1, fi_reg_index=0, fi_bit=55 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-3.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-3.txt deleted file mode 100644 index 1d2a2083..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-3.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflow(API), fi_index=59, fi_cycle=1, fi_reg_index=0, fi_bit=14 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-4.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-4.txt deleted file mode 100644 index 9e1ff590..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.fi.injectedfaults.0-4.txt +++ /dev/null @@ -1 +0,0 @@ -FI stat: fi_type=BufferOverflow(API), fi_index=59, fi_cycle=1, fi_reg_index=0, fi_bit=43 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-0.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-0.txt deleted file mode 100644 index 5602e4cc..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-0.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 52 -ID: 59 OPCode: call Value: 000000000000005f -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000001d69a40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000001d69a40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000001d69a40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001d69a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-1.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-1.txt deleted file mode 100644 index 14c8d679..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-1.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 33 -ID: 40 OPCode: call Value: 0000000000000032 -ID: 41 OPCode: load Value: 0000000001272a40 -ID: 42 OPCode: call Value: 00000000 -ID: 43 OPCode: load Value: 00000000012729c0 -ID: 44 OPCode: call Value: 00000052 -ID: 45 OPCode: load Value: 0000000001272a00 -ID: 46 OPCode: load Value: 00000000012729c0 -ID: 47 OPCode: call Value: 00000000 -ID: 48 OPCode: load Value: 0000000001272a00 -ID: 49 OPCode: call Value: 0000004d -ID: 50 OPCode: call Value: 0000000001272a40 -ID: 51 OPCode: store Value: 00000000 -ID: 52 OPCode: load Value: 0000000001272a40 -ID: 53 OPCode: icmp Value: 01 -ID: 54 OPCode: br Value: 00000000 -ID: 55 OPCode: load Value: 0000000001272a00 -ID: 56 OPCode: load Value: 00000032 -ID: 57 OPCode: sext Value: 0000000000000032 -ID: 58 OPCode: load Value: 0000000001272a40 -ID: 59 OPCode: call Value: 0000000000000032 -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000001272a40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000001272a40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000001272a40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001272a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-2.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-2.txt deleted file mode 100644 index fd1a4957..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-2.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 52 -ID: 59 OPCode: call Value: 000000000000005f -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 0000000001b25a40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 0000000001b25a40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 0000000001b25a40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 0000000001b25a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-3.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-3.txt deleted file mode 100644 index 4774f21d..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-3.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 52 -ID: 59 OPCode: call Value: 000000000000005f -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 00000000016b5a40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 00000000016b5a40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 00000000016b5a40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 00000000016b5a40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-4.txt b/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-4.txt deleted file mode 100644 index d572bcab..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/llfi_stat_output/llfi.stat.trace.0-4.txt +++ /dev/null @@ -1,251 +0,0 @@ -#TraceStartInstNumber: 52 -ID: 59 OPCode: call Value: 000000000000005f -ID: 60 OPCode: br Value: 00000000 -ID: 61 OPCode: load Value: 000000000150ba40 -ID: 62 OPCode: call Value: 00000000 -ID: 63 OPCode: call Value: 000000000150ba40 -ID: 64 OPCode: call Value: 00000021 -ID: 65 OPCode: load Value: 000000000150ba40 -ID: 66 OPCode: call Value: 00000074 -ID: 67 OPCode: trunc Value: 74 -ID: 68 OPCode: store Value: 00000000 -ID: 69 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000068 -ID: 79 OPCode: trunc Value: 68 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 68 -ID: 71 OPCode: sext Value: 00000068 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 68 -ID: 75 OPCode: sext Value: 00000068 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000069 -ID: 79 OPCode: trunc Value: 69 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 69 -ID: 71 OPCode: sext Value: 00000069 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 69 -ID: 75 OPCode: sext Value: 00000069 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 61 -ID: 71 OPCode: sext Value: 00000061 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 61 -ID: 75 OPCode: sext Value: 00000061 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000065 -ID: 79 OPCode: trunc Value: 65 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 65 -ID: 71 OPCode: sext Value: 00000065 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 65 -ID: 75 OPCode: sext Value: 00000065 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000073 -ID: 79 OPCode: trunc Value: 73 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 73 -ID: 71 OPCode: sext Value: 00000073 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 73 -ID: 75 OPCode: sext Value: 00000073 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000074 -ID: 79 OPCode: trunc Value: 74 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 74 -ID: 71 OPCode: sext Value: 00000074 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 74 -ID: 75 OPCode: sext Value: 00000074 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000020 -ID: 79 OPCode: trunc Value: 20 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 20 -ID: 71 OPCode: sext Value: 00000020 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 20 -ID: 75 OPCode: sext Value: 00000020 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000070 -ID: 79 OPCode: trunc Value: 70 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 70 -ID: 71 OPCode: sext Value: 00000070 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 70 -ID: 75 OPCode: sext Value: 00000070 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 0000006f -ID: 79 OPCode: trunc Value: 6f -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 6f -ID: 71 OPCode: sext Value: 0000006f -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 6f -ID: 75 OPCode: sext Value: 0000006f -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000067 -ID: 79 OPCode: trunc Value: 67 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 67 -ID: 71 OPCode: sext Value: 00000067 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 67 -ID: 75 OPCode: sext Value: 00000067 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000072 -ID: 79 OPCode: trunc Value: 72 -ID: 80 OPCode: store Value: 00000000 -ID: 81 OPCode: br Value: 00000000 -ID: 70 OPCode: load Value: 72 -ID: 71 OPCode: sext Value: 00000072 -ID: 72 OPCode: icmp Value: 01 -ID: 73 OPCode: br Value: 00000000 -ID: 74 OPCode: load Value: 72 -ID: 75 OPCode: sext Value: 00000072 -ID: 76 OPCode: call Value: 00000001 -ID: 77 OPCode: load Value: 000000000150ba40 -ID: 78 OPCode: call Value: 00000061 -ID: 79 OPCode: trunc Value: 61 -ID: 80 OPCode: store Value: 00000000 diff --git a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-0.txt b/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-0.txt deleted file mode 100644 index 6bb358eb..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-0.txt and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-1.txt b/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-1.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-1.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-2.txt b/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-2.txt deleted file mode 100644 index a88e40ba..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-2.txt and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-3.txt b/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-3.txt deleted file mode 100644 index 0fe9d426..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-3.txt and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-4.txt b/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-4.txt deleted file mode 100644 index b8b7b193..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/prog_output/output.0-4.txt and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-0 b/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-0 deleted file mode 100644 index cba20844..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-0 and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-1 b/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-1 deleted file mode 100644 index 3804ea3f..00000000 --- a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-1 +++ /dev/null @@ -1,4 +0,0 @@ -The content of input file is: this is a test program for wrong format in API!!!! -The content of buffer is: this is a test program for wrong format in API!!!! - the content of OUTPUT file is : -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-2 b/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-2 deleted file mode 100644 index 44f84c0a..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-2 and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-3 b/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-3 deleted file mode 100644 index 49397303..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-3 and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-4 b/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-4 deleted file mode 100644 index fac34c41..00000000 Binary files a/test_suite/Traces/BufferOverflow_API/llfi/std_output/std_outputfile-run-0-4 and /dev/null differ diff --git a/test_suite/Traces/BufferOverflow_API/sample.txt b/test_suite/Traces/BufferOverflow_API/sample.txt deleted file mode 100644 index 3061975b..00000000 --- a/test_suite/Traces/BufferOverflow_API/sample.txt +++ /dev/null @@ -1 +0,0 @@ -this is a test program for wrong format in API!!!! \ No newline at end of file diff --git a/test_suite/test_suite.yaml b/test_suite/test_suite.yaml index e4a82511..31e1ca1c 100644 --- a/test_suite/test_suite.yaml +++ b/test_suite/test_suite.yaml @@ -50,24 +50,6 @@ Traces: - llfi/llfi_stat_output/llfi.stat.trace.0-3.txt - llfi/llfi_stat_output/llfi.stat.trace.0-4.txt cdfg_prof: llfi.stat.graph.dot - BufferOverflow_API: - trace_prof: llfi/baseline/llfi.stat.trace.prof.txt - trace_inject: - - llfi/llfi_stat_output/llfi.stat.trace.0-0.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-1.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-2.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-3.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-4.txt - cdfg_prof: llfi.stat.graph.dot - BufferOverflowMemmove_Data: - trace_prof: llfi/baseline/llfi.stat.trace.prof.txt - trace_inject: - - llfi/llfi_stat_output/llfi.stat.trace.0-0.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-1.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-2.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-3.txt - - llfi/llfi_stat_output/llfi.stat.trace.0-4.txt - cdfg_prof: llfi.stat.graph.dot MakefileGeneration: normal_IR: diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index a9d5e3e5..b4004b3e 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -14,7 +14,4 @@ copy(GenerateMakefile.py GenerateMakefile) copy(zgrviewer/llfi_run.sh zgrviewer/run.sh) -#FIDL tests -copydir(FIDL/tests/ FIDL/tests/) - genCopy() diff --git a/tools/CompareLayerOutputs.py b/tools/CompareLayerOutputs.py index e503bda7..07fc004c 100644 --- a/tools/CompareLayerOutputs.py +++ b/tools/CompareLayerOutputs.py @@ -8,7 +8,6 @@ import json import warnings import sys, struct, math -from pdb import set_trace from math import floor, log10 import onnx from collections import OrderedDict @@ -17,6 +16,7 @@ ###### GLOBALS ######### + class DotGraph: Graph = None @@ -24,7 +24,7 @@ class DotGraph: def __init__(self, name): self.Graph = pyg.AGraph(strict=False, directed=True) - self.Graph.graph_attr["label"] = '< Model ' + name + ' >' + self.Graph.graph_attr["label"] = "< Model " + name + " >" self.Graph.graph_attr["compound"] = "true" self.Graph.node_attr["shape"] = "box" self.Graph.node_attr["fillcolor"] = "#EFE8E8" @@ -34,18 +34,17 @@ def addNodes(self, layers): self.Layers = layers for i in range(0, len(layers) - 1, 1): - self.Graph.add_edge(layers[i], layers[i+1]) - node = self.Graph.get_node(layers[i+1]) + self.Graph.add_edge(layers[i], layers[i + 1]) + node = self.Graph.get_node(layers[i + 1]) node.attr["style"] = "filled" node = self.Graph.get_node(layers[0]) node.attr["style"] = "filled" - def formatFloat(self, f): f = str(f) retval = "" - comp = f.split('e') + comp = f.split("e") retval = comp[0] if len(comp) == 2: retval = retval + "E" + comp[1] @@ -59,25 +58,34 @@ def addMisMatch(self, layer, index, elements, old_val, new_val): node.attr["fillcolor"] = "#EFE8E8" node.attr["style"] = "filled,dashed" - NameOfSubGraph = "cluster_"+layer+"_"+str(index) + NameOfSubGraph = "cluster_" + layer + "_" + str(index) old_val_node_name = "Old Value: \n" + self.formatFloat(old_val) + NameOfSubGraph new_val_node_name = "New Value: \n" + self.formatFloat(new_val) + NameOfSubGraph - self.Graph.add_node(old_val_node_name, label=str("Old Value: \n" + self.formatFloat(old_val))) - self.Graph.add_node(new_val_node_name, label=str("New Value: \n" + self.formatFloat(new_val))) + self.Graph.add_node( + old_val_node_name, label=str("Old Value: \n" + self.formatFloat(old_val)) + ) + self.Graph.add_node( + new_val_node_name, label=str("New Value: \n" + self.formatFloat(new_val)) + ) # Add mismatch as a subgraph self.Graph.add_edge(old_val_node_name, new_val_node_name, color="blue") self.Graph.add_edge(layer, old_val_node_name, lhead=NameOfSubGraph) - self.Graph.add_subgraph([old_val_node_name, new_val_node_name], name=NameOfSubGraph, color="blue", label='< Index ' + str(index) + ' / ' + str(elements) + ' >', style='dashed') + self.Graph.add_subgraph( + [old_val_node_name, new_val_node_name], + name=NameOfSubGraph, + color="blue", + label="< Index " + str(index) + " / " + str(elements) + " >", + style="dashed", + ) def printGraph(self): print(self.Graph.string()) def saveGraph(self): - file = open('mismatch.dot', 'w') - file.write(self.Graph.string()) - file.close() + with open("mismatch.dot", "w") as file: + file.write(self.Graph.string()) # Conatins all stuff related to LLTFI FI @@ -131,7 +139,7 @@ def setStatsDir(self, path): self.fi_runtime_stats_dir = path else: print("Invalid Directory path: " + str(path) + " Aborting!") - exit() + sys.exit(1) def argParser(self, argParser): @@ -145,9 +153,11 @@ def argParser(self, argParser): def reportMismatch(self, LayerId, d1_old, d2_new): if not self.is_new_file: - assert (self.fi_layer_id is not None) + assert self.fi_layer_id is not None if str(LayerId) == str(self.fi_layer_id): - self.fi_layer_op_mismatches[self.curr_file_name].append([d1_old, d2_new]) + self.fi_layer_op_mismatches[self.curr_file_name].append( + [d1_old, d2_new] + ) self.total_mismatches_found = self.total_mismatches_found + 1 else: pass @@ -160,22 +170,23 @@ def reportMismatch(self, LayerId, d1_old, d2_new): def getFIStats(self): self.total_runs = len(os.listdir(self.fi_runtime_stats_dir)) - for key in [k for k,v in self.fi_layer_op_mismatches.items() if len(v) > 0]: + for key in [k for k, v in self.fi_layer_op_mismatches.items() if len(v) > 0]: - id = key.split('.')[1] + id = key.split(".")[1] filename = "" for file in os.listdir(self.fi_runtime_stats_dir): if str(id) + ".txt" in file: filename = file - break; + break else: pass if filename == "": continue else: - f = open(str(self.fi_runtime_stats_dir + "/" + filename), "r").read() - bit = str((f.replace(" ", "").split(",")[-2]).split('=')[1]) + with open(str(self.fi_runtime_stats_dir + "/" + filename), "r") as _f: + f = _f.read() + bit = str((f.replace(" ", "").split(",")[-2]).split("=")[1]) temp = [bit] temp = temp + [k for k in self.fi_layer_op_mismatches[key]] @@ -192,7 +203,9 @@ def getFIStats(self): self.single_output_corruption = self.single_output_corruption + 1 else: if len(v) > 2: - self.multiple_output_corruption = self.multiple_output_corruption + 1 + self.multiple_output_corruption = ( + self.multiple_output_corruption + 1 + ) for data_val in v[1:]: d_old = float(data_val[0]) @@ -206,25 +219,28 @@ def getFIStats(self): else: self.single_bit_flips = self.single_bit_flips + 1 - # Check if 2 floats differ by multiple bit flips or just a single bit flip def isMultipleBitFlips(self, d1, d2): - if str(d1) == 'nan' or str(d2) == 'nan': + if str(d1) == "nan" or str(d2) == "nan": self.numbers_of_nan = self.numbers_of_nan + 1 return True h1 = float.hex(float(d1)) h2 = float.hex(float(d2)) - sign1 = h1.split('.')[0] - sign2 = h2.split('.')[0] - exp1 = h1.split('p')[-1] - exp2 = h2.split('p')[-1] - man1 = h1.split('.')[-1].split('p')[0] - man2 = h2.split('.')[-1].split('p')[0] + sign1 = h1.split(".")[0] + sign2 = h2.split(".")[0] + exp1 = h1.split("p")[-1] + exp2 = h2.split("p")[-1] + man1 = h1.split(".")[-1].split("p")[0] + man2 = h2.split(".")[-1].split("p")[0] # Count bit flips in sign, exponent and mantissa - numberOfFlips = self.countBitFlips(sign1, sign2) + self.countBitFlips(exp1, exp2) + self.countBitFlips(man1, man2) + numberOfFlips = ( + self.countBitFlips(sign1, sign2) + + self.countBitFlips(exp1, exp2) + + self.countBitFlips(man1, man2) + ) self.distance_bitflips.append(numberOfFlips) assert numberOfFlips > 0 @@ -251,15 +267,42 @@ def countBitFlips(self, h1, h2): def printSummary(self): print("Total number of files:" + str(self.total_runs)) - print("Total mismatches found in layer " + str(self.fi_layer_id) + " = " + str(self.total_mismatches_found)) + print( + "Total mismatches found in layer " + + str(self.fi_layer_id) + + " = " + + str(self.total_mismatches_found) + ) percentage = (float)(self.total_mismatches_found / self.total_runs) * 100 - print("""% of bitflips that corrupted the output of the FI layer: """ + str(percentage) + "%") - percentage = (float)(self.single_output_corruption / self.total_mismatches_found) * 100 - print("""% of times a single bitflip caused just a single corrpution in layer output: """ + str(percentage) + "%") - print( "FI in bits that cuased a sign flip: " + str(list(set(self.fi_bit_sign_flip))) ) - percentage = (float)(self.single_bit_flips / (self.single_bit_flips + self.multiple_bit_flips)) * 100 - print("% of single bit corruption in the output of the tensor operator: " + str(percentage) + "%") - print("Number of NaNs in output of the corrupted tensor operator: " + str(self.numbers_of_nan)) + print( + """% of bitflips that corrupted the output of the FI layer: """ + + str(percentage) + + "%" + ) + percentage = (float)( + self.single_output_corruption / self.total_mismatches_found + ) * 100 + print( + """% of times a single bitflip caused just a single corrpution in layer output: """ + + str(percentage) + + "%" + ) + print( + "FI in bits that cuased a sign flip: " + + str(list(set(self.fi_bit_sign_flip))) + ) + percentage = (float)( + self.single_bit_flips / (self.single_bit_flips + self.multiple_bit_flips) + ) * 100 + print( + "% of single bit corruption in the output of the tensor operator: " + + str(percentage) + + "%" + ) + print( + "Number of NaNs in output of the corrupted tensor operator: " + + str(self.numbers_of_nan) + ) def printStats(self): print("Mismatches in FI layer:" + str(self.fi_layer_op_mismatches)) @@ -273,6 +316,7 @@ def printDump(self): else: self.printStats() + # Global LLTFI Object LLTFIObject = LLTFI() @@ -288,22 +332,25 @@ def printDump(self): # ONNX model onnxModel = None + ###### HELPER FUNCTIONS ####### def printStructuralDifferenceError(): - print("Input JSON files are structurally difference. Abort!"); - exit() + print("Input JSON files are structurally difference. Abort!") + sys.exit(1) + def assertFun(a): if not a: - printStructuralDifferenceError(); + printStructuralDifferenceError() else: pass -def assertData(d1, d2, elements, layer, index, delta, op_layer_name = ""): + +def assertData(d1, d2, elements, layer, index, delta, op_layer_name=""): if abs(float(d1) - float(d2)) <= float(delta): return False else: - #print("Mismatch found at layer:" + layer + " index: " + index + " one value= " + d1 + " second value= " + d2) + # print("Mismatch found at layer:" + layer + " index: " + index + " one value= " + d1 + " second value= " + d2) mismatch.append([op_layer_name, layer, index, elements, float(d1), float(d2)]) global FIStatsCalculate @@ -324,7 +371,11 @@ def export_dot_graph(layerNames, GraphName, SelectedLayers): layer_name = layerNamesSmall[int(item[0])] - if ('all' in SelectedLayers) or (layer_name.split('_')[0] in SelectedLayers) or (layer_name is layerNamesSmall[-1]): + if ( + ("all" in SelectedLayers) + or (layer_name.split("_")[0] in SelectedLayers) + or (layer_name is layerNamesSmall[-1]) + ): g.addMisMatch(layerNames[int(item[0])], item[1], item[2], item[3], item[4]) g.saveGraph() @@ -340,41 +391,54 @@ def getJsonDiff(j1, j2, delta): jf = json.load(f) jg = json.load(g) - assertFun(len(jf) == len(jg)); + assertFun(len(jf) == len(jg)) # Iterate each layer for i in range(0, len(jf), 1): key = str(i) - #if not (jf[key]['Rank'] == jg[key]['Rank'] and jf[key]['Number of Elements'] == jg[key]['Number of Elements'] and jf[key]['Shape'] == jg[key]['Shape']): - #set_trace() + # if not (jf[key]['Rank'] == jg[key]['Rank'] and jf[key]['Number of Elements'] == jg[key]['Number of Elements'] and jf[key]['Shape'] == jg[key]['Shape']): + # set_trace() - if jf[key]['Rank'] == 0: - continue + if jf[key]["Rank"] == 0: + continue - #assertFun(jf[key]['Layer Id'] == jg[key]['Layer Id']) - assertFun(jf[key]['Rank'] == jg[key]['Rank']) - assertFun(jf[key]['Number of Elements'] == jg[key]['Number of Elements']) - assertFun(jf[key]['Shape'] == jg[key]['Shape']) + # assertFun(jf[key]['Layer Id'] == jg[key]['Layer Id']) + assertFun(jf[key]["Rank"] == jg[key]["Rank"]) + assertFun( + jf[key]["Number of Elements"] == jg[key]["Number of Elements"] + ) + assertFun(jf[key]["Shape"] == jg[key]["Shape"]) - layerId = jf[key]['Layer Id'] + layerId = jf[key]["Layer Id"] op_name = "" if isNLPModel: layerId = i global onnxModel import onnx + m = onnx.load(onnxModel) op_name = m.graph.output[i].name - #Iterate all data outputs - for j in range(0, len(jf[key]['Data']), 1): - retval = assertData(str(jf[key]['Data'][j]), str(jg[key]['Data'][j]), str(jf[key]['Number of Elements']), str(layerId), str(j), delta, op_name) - except: + # Iterate all data outputs + for j in range(0, len(jf[key]["Data"]), 1): + retval = assertData( + str(jf[key]["Data"][j]), + str(jg[key]["Data"][j]), + str(jf[key]["Number of Elements"]), + str(layerId), + str(j), + delta, + op_name, + ) + except Exception: pass + # Get total number of layers in a model layers_in_json_file = None + def getTotalLayers(j1): global layers_in_json_file @@ -382,7 +446,7 @@ def getTotalLayers(j1): if layers_in_json_file is None: with open(j1, "r") as f: jf = json.load(f) - layers_in_json_file = int(jf[str(len(jf) - 1)]['Layer Id']) + layers_in_json_file = int(jf[str(len(jf) - 1)]["Layer Id"]) return layers_in_json_file @@ -395,15 +459,65 @@ def main(): parser.add_argument("second") parser.add_argument("ONNXModel") - parser.add_argument("-d", "--delta", action="store", type=str, default="0", help="What should be the minimum difference between the mismatches? \t Default: 0.0") - parser.add_argument("-f", "--folder", action="store_true", default=False, help="Specify this flag if the second argument is a folder path? \t Default: False") - parser.add_argument("--dot", action="store_true", default=False, help="Do you want to output a dot file? \t Default: False") - parser.add_argument("--nlp", action="store_true", default=False, help="Is the model a NLP model? \t Default: False") - parser.add_argument("-fdiff", "--print_diff_final_op", action="store_true", default=False, help="Print mismatch on terminal only when there is a mismatch in the final layer o/p. \t Default: False") - parser.add_argument("--selected_layers_in_dot", action="store", default="all", help="""Show only mismatches in selected layers during DOT file generation. Semicolon seperate values. \t Default: "all" """) - parser.add_argument("--summary", action="store_true", default=False, help="Do you just want to output a summary of mismatches? \t Default: False") - parser.add_argument("--getFIStatsRQ2", action="store_true", default=False, help="Should I compute FI stats? \t Default: False") - parser.add_argument("--fiStatsDir", action="store", type=str, default="None", help="Directory path containg logs produced by LLTFI at runtime. \t Default: None") + parser.add_argument( + "-d", + "--delta", + action="store", + type=str, + default="0", + help="What should be the minimum difference between the mismatches? \t Default: 0.0", + ) + parser.add_argument( + "-f", + "--folder", + action="store_true", + default=False, + help="Specify this flag if the second argument is a folder path? \t Default: False", + ) + parser.add_argument( + "--dot", + action="store_true", + default=False, + help="Do you want to output a dot file? \t Default: False", + ) + parser.add_argument( + "--nlp", + action="store_true", + default=False, + help="Is the model a NLP model? \t Default: False", + ) + parser.add_argument( + "-fdiff", + "--print_diff_final_op", + action="store_true", + default=False, + help="Print mismatch on terminal only when there is a mismatch in the final layer o/p. \t Default: False", + ) + parser.add_argument( + "--selected_layers_in_dot", + action="store", + default="all", + help="""Show only mismatches in selected layers during DOT file generation. Semicolon seperate values. \t Default: "all" """, + ) + parser.add_argument( + "--summary", + action="store_true", + default=False, + help="Do you just want to output a summary of mismatches? \t Default: False", + ) + parser.add_argument( + "--getFIStatsRQ2", + action="store_true", + default=False, + help="Should I compute FI stats? \t Default: False", + ) + parser.add_argument( + "--fiStatsDir", + action="store", + type=str, + default="None", + help="Directory path containg logs produced by LLTFI at runtime. \t Default: None", + ) args = parser.parse_args() @@ -445,13 +559,13 @@ def main(): if args.print_diff_final_op: for key, value in list(global_mismatch.items()): - mismatch_found = False; + mismatch_found = False for mismatch in value: if mismatch[0] == str(getTotalLayers(args.first)): mismatch_found = True - break; + break # Remove the mismatches where the last layer doesn't differ if not mismatch_found: @@ -461,16 +575,23 @@ def main(): # Print summary only if args.summary: - print("Mismatches found in " + str(len([key for key in global_mismatch.keys()])) + " file(s).") + print( + "Mismatches found in " + + str(len([key for key in global_mismatch.keys()])) + + " file(s)." + ) else: - print("Mismatch found in: " + str([key for key in global_mismatch.keys()])) + print( + "Mismatch found in: " + + str([key for key in global_mismatch.keys()]) + ) else: print("No mismatch found.") else: print("Invalid directory path: " + args.second) - exit() + sys.exit(1) else: if os.path.isfile(args.first) and os.path.isfile(args.second): @@ -494,7 +615,7 @@ def main(): print("No mismatch found.") else: print("Invalid file path(s): " + args.first + " \n " + args.second) - exit() + sys.exit(1) # Export mismatches as a dot file. Don't output the dot file when folder is given as an input. if args.dot and (not args.folder): @@ -503,18 +624,21 @@ def main(): output_names = [] # Output Layer names for i in range(0, len(model.graph.node), 1): - output_names.append(str(model.graph.node[i].op_type) + "_" + str(i+1)) + output_names.append(str(model.graph.node[i].op_type) + "_" + str(i + 1)) # Now export a dot graph showing the layer structure and mismatch points - export_dot_graph(output_names, model.graph.doc_string, args.selected_layers_in_dot.split(';')) + export_dot_graph( + output_names, model.graph.doc_string, args.selected_layers_in_dot.split(";") + ) if args.getFIStatsRQ2: LLTFIObject.printDump() + def main_deprecated(): warnings.warn("jsondiff is deprecated. Use jdiff instead.", DeprecationWarning) main() -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/tools/ExtendONNXModel.py b/tools/ExtendONNXModel.py index cb655398..623239c3 100644 --- a/tools/ExtendONNXModel.py +++ b/tools/ExtendONNXModel.py @@ -6,15 +6,33 @@ from collections import OrderedDict parser = argparse.ArgumentParser() -parser.add_argument('--model_path', type=str, required=True, help="Path to the ONNX model") -parser.add_argument('--output_model_path', type=str, required=True, help="Path to the extended ONNX model") -parser.add_argument('--layers', type=str, default="all", help="""Intermediate layers for which output has to be extracted. Semicolon seperated. like: --layers="conv;maxpool" """) -parser.add_argument('--summary', default=False, action="store_true", help="Just return the summary of the model") +parser.add_argument( + "--model_path", type=str, required=True, help="Path to the ONNX model" +) +parser.add_argument( + "--output_model_path", + type=str, + required=True, + help="Path to the extended ONNX model", +) +parser.add_argument( + "--layers", + type=str, + default="all", + help="""Intermediate layers for which output has to be extracted. Semicolon seperated. like: --layers="conv;maxpool" """, +) +parser.add_argument( + "--summary", + default=False, + action="store_true", + help="Just return the summary of the model", +) args = parser.parse_args() layers = [] all_layers = False + def extend_model_output(model, intermediate_outputs): global layers # onnx-mlir doesn't care about manually specified output types & shapes. @@ -23,43 +41,54 @@ def extend_model_output(model, intermediate_outputs): original_op = model.graph.output[0].name # Remove all outputs from the model - while (len(model.graph.output)): + while len(model.graph.output): model.graph.output.pop() - i = -1; + i = -1 layer_output = [] for output_name in intermediate_outputs: - i = i + 1 # Current layer - if all_layers and ('Constant' not in output_name) and ('Gather' not in output_name): - output_value_info = onnx.helper.make_tensor_value_info(output_name, DUMMY_TENSOR_TYPE, None) + i = i + 1 # Current layer + if ( + all_layers + and ("Constant" not in output_name) + and ("Gather" not in output_name) + ): + output_value_info = onnx.helper.make_tensor_value_info( + output_name, DUMMY_TENSOR_TYPE, None + ) model.graph.output.extend([output_value_info]) layer_output.append(i) else: temp = [n for n in layers if n in output_name] if len(temp) > 0: - output_value_info = onnx.helper.make_tensor_value_info(output_name, DUMMY_TENSOR_TYPE, None) + output_value_info = onnx.helper.make_tensor_value_info( + output_name, DUMMY_TENSOR_TYPE, None + ) model.graph.output.extend([output_value_info]) layer_output.append(i) - - if (not all_layers) and ( (len(intermediate_outputs) - 1) not in layer_output): + if (not all_layers) and ((len(intermediate_outputs) - 1) not in layer_output): # Add original output - output_value_info = onnx.helper.make_tensor_value_info(original_op, DUMMY_TENSOR_TYPE, None) + output_value_info = onnx.helper.make_tensor_value_info( + original_op, DUMMY_TENSOR_TYPE, None + ) model.graph.output.extend([output_value_info]) layer_output.append(len(intermediate_outputs) - 1) return model, layer_output + def print_summary(model): - print(str([[n for n in node.output if n != ''] for node in model.graph.node])) + print(str([[n for n in node.output if n != ""] for node in model.graph.node])) + def main(): global layers global all_layers - layers = (args.layers).split(';') + layers = (args.layers).split(";") - if 'all' in layers: + if "all" in layers: all_layers = True # Load the onnx model. @@ -67,17 +96,19 @@ def main(): if args.summary: print_summary(model) - exit() + sys.exit(0) output_names = [o.name for o in model.graph.output] output_names = list(OrderedDict.fromkeys(output_names)) - output_names = sum([[n for n in node.output if n != ''] for node in model.graph.node], []) + output_names = sum( + [[n for n in node.output if n != ""] for node in model.graph.node], [] + ) output_names = list(OrderedDict.fromkeys(output_names)) model, layer_output = extend_model_output(model, output_names) onnx.save(model, args.output_model_path) - print(str(layer_output).strip('[').strip(']').replace(' ', '')) + print(str(layer_output).strip("[").strip("]").replace(" ", "")) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/tools/FIDL/FIDL-Algorithm.py b/tools/FIDL/FIDL-Algorithm.py deleted file mode 100755 index d902fcc6..00000000 --- a/tools/FIDL/FIDL-Algorithm.py +++ /dev/null @@ -1,642 +0,0 @@ -#! /usr/bin/env python3 - -""" -%(prog)s takes a FIDL (Fault Injection Description Language) yaml and -generates an instruction/register selector C++ code, and a fault injection -run-time C++ code. - -Usage: %(prog)s [OPTIONS] - -List of options: --a : add a FI run-time and selector from a FIDL yaml --r : removes the specified injector by '()' - or remove all 'custom' or 'default' injector --l : lists all active injectors/selectors by 'custom' or 'default' --h : shows help - -Every time the content of a FIDL yaml is changed, this script should be executed -(-a ) to reflect the change(s) in the generated C++ code. Rebuild -LLFI after adding or changing a FIDL yaml to use the new/changed injector(s). - -Failure Class and Failure Mode pair should be unique, otherwise the previous -Failure Class and Failure Mode pair is overwritten. - -It is assumed that the script is located in the '/tools/FIDL/' -directory. - -For more information, see: -https://github.com/DependableSystemsLab/LLFI/wiki/Using-FIDL-to-create-a-Custom-Software-Fault-Injector-and-a-Custom-Instruction-Selector -""" - -import sys, os, subprocess -import time -import yaml - -################################################################################ - -script_dir = os.path.realpath(os.path.dirname(__file__)) -prog = os.path.basename(sys.argv[0]) - -llfiroot = os.path.dirname(os.path.dirname(script_dir)) - -fidl_runtime_path = os.path.join(llfiroot, 'runtime_lib/_FIDLSoftwareFaultInjectors.cpp') -software_failures_passes_dir = os.path.join(llfiroot, 'llvm_passes/software_failures/') -cmakelists = os.path.join(llfiroot, 'llvm_passes/CMakeLists.txt') - -setup_script = os.path.join(llfiroot, 'setup.py') - -config_dir = os.path.join(script_dir, 'config') - -all_injectors_yaml = os.path.join(config_dir, 'injectors.yaml') -default_failures_yaml = os.path.join(config_dir, 'default_failures.yaml') - -injector_template = os.path.join(config_dir, 'NewInjectorTemplate.cpp') -single_template = os.path.join(config_dir, 'TargetSingleTemplate.cpp') -all_template = os.path.join(config_dir, 'TargetAllTemplate.cpp') -multisrc_template = os.path.join(config_dir, 'TargetMultiSourceTemplate.cpp') - -################################################################################ - -def read_file(file_name): - with open(file_name) as f: - lines = f.read().splitlines() - return lines - -def write_file(file_name, lines): - with open(file_name, 'w') as f: - for line in lines: - f.write('%s\n' % line) - -def write_yaml(obj, path): - f = open(path, 'w') - yaml.dump(obj, f) - f.close() - -def read_input_yaml(filename): - # Check for Input FIDL's presence - try: - f = open(filename, 'r') - except: - print('Error: Specified FIDL config file (%s) not found!' % (filename), file = sys.stderr) - exit(1) - - # Check for correct YAML formatting - try: - doc = yaml.safe_load(f) - f.close() - except: - print('Error: %s is not formatted in proper YAML format (reminder: use spaces, not tabs)' % (filename), file = sys.stderr) - exit(1) - - return doc - -def parse_input(doc): - options = {} - # Load values and check if FIDL options are valid - try: - options['f_class'] = doc['Failure_Class'] - options['f_mode'] = doc['Failure_Mode'] - nfm = doc['New_Failure_Mode'] - - # parses 'Trigger' - trigger = nfm['Trigger'] - - if 'call' in trigger or 'call*' in trigger: - if 'call' in trigger: - options['insts'] = trigger_insts = trigger['call'] - options['trigger_type'] = 'call' - else: - options['insts'] = trigger_insts = trigger['call*'] - options['trigger_type'] = 'call*' - if 'all' in trigger_insts: - raise Exception("Error: Cannot instrument both 'call*' and 'all'!") - - # parses 'Target' - target = nfm['Target'] - if 'src' in target and 'dst' in target: - raise Exception('Error: Invalid trigger module (both src and dst usage not allowed)') - elif 'src' in target: - target_insts = target['src'] - if (set(target_insts) != set(trigger_insts) or # need to specify at least one src for each instruction - bool([target_inst for inst in target_insts.values() if inst == None or inst == [] or inst == '' or not isinstance(inst, list)])): # check that the specified src's aren't empty, an empty list, or an empty string, or isn't a list - raise Exception("Error: Invalid number/name of src's in Target, or Target sources are not specified as list!") - - if 'all' in target_insts and len(target_insts) != 1: - raise Exception('Error: When instrumenting all call instruction, only 1 register can be specified!') - - options['insts'] = target_insts # replace with target insts if we're dealing with src - options['reg_type'] = 'src' - elif 'dst' in target: - options['reg_type'] = 'dst' - else: - raise Exception('Error: Invalid register target type!') - elif 'return' in trigger: - options['trigger_type'] = 'return' - options['reg_type'] = 'ret' # unused - options['insts'] = [] # unused - else: - raise Exception('Error: Trigger option (call, call*, or ret) not found!') - - # parses 'Trigger*' - if 'Trigger*' in nfm: - options['trigger_s'] = nfm['Trigger*'] - - # parses 'Action' - options['action'] = action = nfm['Action'] - if 'Perturb' in action: - perturb = action['Perturb'] - if 'Custom_Injector' in perturb: - options['custom_injector'] = doc['Custom_Injector'] - - # raises any errors occured during parsing - except Exception as e: - raise e - - # {f_class, f_mode, trigger_type, reg_type, action, injector* (from gen_runtime_code())} - # {insts* (if src selected), trigger_s* (if indices selected), custom_injector* (if implemented)} - return options - -################################################################################ - -def gen_ftrigger_single(options): - f_class = options['f_class'] - f_mode = options['f_mode'] - trigger_type = options['trigger_type'] - injector = options['injector'] - insts = options['insts'] - reg_type = options['reg_type'] - - # convert trigger and target of .fidl file into appropriate llvm passes - lines = read_file(single_template) - - add_current_time(lines) # attaches a generated date/time to the file - - i = lines.index('//fidl_1') - lines.insert(i + 1, 'class _%s_%sInstSelector : public SoftwareFIInstSelector {' % (f_class, f_mode)) - - i = lines.index('//fidl_2') - lines.insert(i + 1, ' _%s_%sInstSelector() {' % (f_class, f_mode)) - - i = lines.index('//fidl_3') - for n in insts: - lines.insert(i + 1, ' funcNames.insert(std::string("%s"));' % n) - - i = lines.index('//fidl_4') - lines.insert(i + 1, ' info["failure_class"] = "%s";' % (f_class)) - lines.insert(i + 2, ' info["failure_mode"] = "%s";' % (f_mode)) - - lines.insert(lines.index('//fidl_5') + 1, ' info["injector"] = "%s";' % (injector)) - - i = lines.index('//fidl_6') - insert = '' - if 'trigger_s' in options: - insert = ' && isTargetLLFIIndex(inst)' - if trigger_type != 'call*': - lines.insert(i + 1, ' return funcNames.find(func_name) != funcNames.end()%s;' % (insert)) - else: - lines.insert(i + 1, ' return key_partially_matches(func_name) != funcNames.end()%s;' % (insert)) - - lines.insert(lines.index('//fidl_7') + 1, gen_targeted_indices(options)) - - lines.append('std::set _%s_%sInstSelector::funcNames;\n' % (f_class, f_mode)) - - lines.append('static RegisterFIInstSelector A("%s(%s)", new _%s_%sInstSelector());' % (f_mode, f_class, f_class, f_mode)) - # change reg_type - if reg_type == 'src': - lines.append('static RegisterFIRegSelector B("%s(%s)", new FuncArgRegSelector(%s));\n\n}\n' % (f_mode, f_class, next(iter(insts.values()))[0])) - elif reg_type == 'dst': - lines.append('static RegisterFIRegSelector B("%s(%s)", new FuncDestRegSelector());\n\n}\n' % (f_mode, f_class)) - - return lines - -def gen_ftrigger_all(options): - f_class = options['f_class'] - f_mode = options['f_mode'] - trigger_type = options['trigger_type'] - injector = options['injector'] - insts = options['insts'] - reg_type = options['reg_type'] - - lines = read_file(all_template) - - add_current_time(lines) # attaches a generated date/time to the file - - i = lines.index('//fidl_1') - lines.insert(i + 1, 'class _%s_%sInstSelector : public SoftwareFIInstSelector {' % (f_class, f_mode)) - - i = lines.index('//fidl_2') - lines.insert(i + 1, ' info["failure_class"] = "%s";' % (f_class)) - lines.insert(i + 2, ' info["failure_mode"] = "%s";' % (f_mode)) - lines.insert(i + 3, ' info["injector"] = "%s";' % (injector)) - if trigger_type == 'return': - lines.insert(i + 4, ' info["targets"] = "return";') - else: - lines.insert(i + 4, ' info["targets"] = "all-call-instructions";') - - i = lines.index('//fidl_3') - insert = '' - if 'trigger_s' in options: - insert = ' && isTargetLLFIIndex(inst)' - if trigger_type == 'return': - lines.insert(i + 1, ' return isa(inst)%s;' % (insert)) - else: - lines.insert(i + 1, ' return isa(inst)%s;' % (insert)) - - lines.insert(lines.index('//fidl_4') + 1, gen_targeted_indices(options)) - - lines.append('static RegisterFIInstSelector A("%s(%s)", new _%s_%sInstSelector());' % (f_mode, f_class, f_class, f_mode)) - if trigger_type == 'return': - lines.append('static RegisterFIRegSelector B("%s(%s)", new RetValRegSelector());\n\n}\n' % (f_mode, f_class)) - else: - if reg_type == 'src': - lines.append('static RegisterFIRegSelector B("%s(%s)", new FuncArgRegSelector(%s));\n\n}\n' % (f_mode, f_class, next(iter(insts.values()))[0])) - elif reg_type == 'dst': - lines.append('static RegisterFIRegSelector B("%s(%s)", new FuncDestRegSelector());\n\n}\n' % (f_mode, f_class)) - - return lines - -def gen_ftrigger_multisrc(options): - f_class = options['f_class'] - f_mode = options['f_mode'] - trigger_type = options['trigger_type'] - insts = options['insts'] - injector = options['injector'] - - lines = read_file(multisrc_template) - - add_current_time(lines) # attaches a generated date/time to the file - - i = lines.index('//fidl_1') - lines.insert(i + 1, 'class _%s_%sInstSelector : public SoftwareFIInstSelector {' % (f_class, f_mode)) # Trigger: "fread" - - i = lines.index('//fidl_2') - lines.insert(i + 1,' _%s_%sInstSelector () {' % (f_class,f_mode)) - - i = lines.index('//fidl_3') - for inst in insts: - lines.insert(i + 1, ' funcNamesTargetArgs["%s"] = std::set();' % inst) - for reg in insts[inst]: - lines.insert(i + 2, ' funcNamesTargetArgs["%s"].insert(%s);' % (inst, reg)) - - i = lines.index('//fidl_4') - lines.insert(i + 1, ' info["failure_class"] = "%s";' % (f_class)) - lines.insert(i + 2, ' info["failure_mode"] = "%s";' % (f_mode)) - - lines.insert(lines.index('//fidl_5') + 1, ' info["injector"] = "%s";' % (injector)) - - i = lines.index('//fidl_6') - if trigger_type != 'call*': - lines.insert(i + 1, ' if (funcNamesTargetArgs.find(func_name) == funcNamesTargetArgs.end()) {') - lines.insert(i + 2, ' return false;') - lines.insert(i + 3, ' }') - lines.insert(i + 4, ' for (std::set::iterator SI = funcNamesTargetArgs[func_name].begin(); SI != funcNamesTargetArgs[func_name].end(); SI++) {') - else: - lines.insert(i + 1, ' std::map >::iterator it = key_partially_matches(func_name);') - lines.insert(i + 2, ' if (it == funcNamesTargetArgs.end()) {') - lines.insert(i + 3, ' return false;') - lines.insert(i + 4, ' }') - lines.insert(i + 5, ' for (std::set::iterator SI = it->second.begin(); SI != it->second.end(); SI++) {') - - i = lines.index('//fidl_7') - insert = '' - if 'trigger_s' in options: - insert = ' && isTargetLLFIIndex(inst)' - if trigger_type != 'call*': - lines.insert(i + 1, ' return funcNamesTargetArgs.find(func_name) != funcNamesTargetArgs.end()%s;' % (insert)) - else: - lines.insert(i + 1, ' return key_partially_matches(func_name) != funcNamesTargetArgs.end()%s;' % (insert)) - - lines.insert(lines.index('//fidl_8') + 1, gen_targeted_indices(options)) - - i = lines.index('//fidl_9') - lines.insert(i + 1, 'std::map > _%s_%sInstSelector::funcNamesTargetArgs;\n' % (f_class, f_mode)) - lines.insert(i + 2, 'class _%s_%sRegSelector: public SoftwareFIRegSelector {' % (f_class, f_mode)) - - lines.insert(lines.index('//fidl_10') + 1, ' if (_%s_%sInstSelector::isTarget(CI, reg)) {\n return true;' % (f_class, f_mode)) - - lines.append('static RegisterFIInstSelector A("%s(%s)", new _%s_%sInstSelector());' % (f_mode, f_class, f_class, f_mode)) - lines.append('static RegisterFIRegSelector B("%s(%s)", new _%s_%sRegSelector());\n\n}\n' % (f_mode, f_class, f_class, f_mode)) - - return lines - -def gen_and_write_selector(options) : - f_class = options['f_class'] - f_mode = options['f_mode'] - - trigger_type = options['trigger_type'] - reg_type = options['reg_type'] - insts = options['insts'] - - # complete instrumenting pass development by printing the pass content into a file. - # write to a file - filename = '_%s_%sSelector.cpp' % (f_class, f_mode) - filepath = os.path.join(software_failures_passes_dir, filename) - - if trigger_type == 'return' or 'all' in insts: # targering all call instructions or all 'ret' - write_file(filepath, gen_ftrigger_all(options)) - elif reg_type == 'src' and not is_one_src_register(insts): # multisrc - write_file(filepath, gen_ftrigger_multisrc(options)) - elif reg_type == 'src' or reg_type == 'dst': # dst or singlesrc - write_file(filepath, gen_ftrigger_single(options)) - else: - print('Error: Invalid FIDL config!', file = sys.stderr) - exit(1) - - # modify llvm_pass/CMakeLists.txt - l = read_file(cmakelists) - try: - l.index(' software_failures/%s' % filename) - except: - l.insert(l.index(" #FIDL - DO NOT MODIFY UNTIL '#END'") + 1, ' software_failures/%s' % filename) - write_file(cmakelists, l) - - print('Instrument module created.') - -# checks if we are only instrumenting a single src register -def is_one_src_register(insts): - init_val = next(iter(insts.values()))[0] - for inst in insts.values(): - if len(inst) > 1 or inst[0] != init_val: - return False - - return True - -def gen_targeted_indices(options): - if 'trigger_s' in options: - trigger_s = options['trigger_s'] - targeted_indices = ', '.join(str(s) for s in trigger_s) - n = len(trigger_s) - else: - targeted_indices = '' - n = 0 - - return ' const long n = %s;\n' % (n) + \ - ' const long targeted_indices[] = {%s};' % (targeted_indices) - -def add_current_time(lines): - lines[1] += time.strftime('%Y/%m/%d %H:%M:%S %Z') # yyyy/mm/dd 24hrClock Timezone - -################################################################################ - -def gen_runtime_code(options, injectors_dict): - f_class = options['f_class'] - f_mode = options['f_mode'] - action = options['action'] - - name = '%s(%s)' % (f_mode, f_class) - selectorfilename = '_%s_%sSelector.cpp' % (f_class, f_mode) - code = [] - injector = '' # for use in info['injector'] in the selector file - - insert = '_%s_%sFIDLInjector("%s(%s)",' % (f_class, f_mode, f_mode, f_class) - - if 'Corrupt' in action: - injector = 'BitCorruptionInjector'; - code.append('static RegisterFaultInjector %s BitCorruptionInjector::getBitCorruptionInjector());' % (insert)) - elif 'Freeze' in action: - injector = 'HangInjector' - code.append('static RegisterFaultInjector %s new HangInjector());' % (insert)) - elif 'Delay' in action: - injector = 'SleepInjector' - code.append('static RegisterFaultInjector %s new SleepInjector());' % (insert)) - elif 'Perturb' in action: - perturb = action['Perturb'] - - # certain perturb actions needs more information - if 'option' in action: - if action['option']: - boolean = 'true' - else: - boolean = 'false' - if 'value' in action: - value = action['value'] - - if 'MemoryLeakInjector' in perturb: - injector = 'MemoryLeakInjector' - code.append('static RegisterFaultInjector %s new MemoryLeakInjector());' % (insert)) - elif 'ChangeValueInjector' in perturb: - injector = 'ChangeValueInjector' - try: - code.append('static RegisterFaultInjector %s new ChangeValueInjector(%s, %s));' % (insert, value, boolean)) - except NameError: - print("Error: 'Perturb: %s' injector requires a integer value under 'Action: value:', and a boolean option under 'Action: option:'!" % (injector), file = sys.stderr) - exit(1) - elif 'InappropriateCloseInjector' in perturb: - injector = 'InappropriateCloseInjector' - try: - code.append('static RegisterFaultInjector %s new InappropriateCloseInjector(%s));' % (insert, boolean)) - except NameError: - print("Error: 'Perturb: %s' injector requires a boolean option under 'Action: option:'!" % (injector), file = sys.stderr) - exit(1) - elif 'MemoryExhaustionInjector' in perturb: - injector = 'MemoryExhaustionInjector' - try: - code.append('static RegisterFaultInjector %s new MemoryExhaustionInjector(%s));' % (insert, boolean)) - except NameError: - print("Error: 'Perturb: %s' injector requires a boolean option under 'Action: option:'!" % (injector), file = sys.stderr) - exit(1) - elif 'WrongFormatInjector' in perturb: - injector = 'WrongFormatInjector' - code.append('static RegisterFaultInjector %s new WrongFormatInjector());' % (insert)) - elif 'PthreadDeadLockInjector' in perturb: - injector = 'PthreadDeadLockInjector' - code.append('static RegisterFaultInjector %s new PthreadDeadLockInjector());' % (insert)) - elif 'PthreadThreadKillerInjector' in perturb: - injector = 'PthreadThreadKillerInjector' - code.append('static RegisterFaultInjector %s new PthreadThreadKillerInjector());' % (insert)) - elif 'PthreadRaceConditionInjector' in perturb: - injector = 'PthreadRaceConditionInjector' - code.append('static RegisterFaultInjector %s new PthreadRaceConditionInjector());' % (insert)) - elif 'StalePointerInjector' in perturb: - injector = 'StalePointerInjector' - code.append('static RegisterFaultInjector %s new StalePointerInjector());' % (insert)) - elif 'Custom_Injector' in perturb: - injector = 'CustomInjector' - if 'custom_injector' in options: - code.extend(gen_custom_injector(insert, f_class, f_mode, options['custom_injector'])) - else: - print("Error: Custom_Injector specified in 'Perturb:' but not specified in .yaml file!", file = sys.stderr) - exit(1) - else: - print('Error: Invalid Perturb Injector!', file = sys.stderr) - exit(1) - else: - print("Error: Invalid 'Action:' field in yaml file!", file = sys.stderr) - exit(1) - - options['injector'] = injector - injectors_dict[name] = {'selectorfilename': selectorfilename, 'code': '\n'.join(code)} - -def gen_runtime_file(injectors): - content = [] - content.append('// DO NOT MODIFY\n#include "_SoftwareFaultInjectors.cpp"\n') - - content.append('/*********************') - content.append(' * DEFAULT INJECTORS *') - content.append(' *********************/\n') - - for key, value in injectors['default'].items(): - content.append('// ' + key) - content.append(value['code'] + '\n') - - content.append('/********************') - content.append(' * CUSTOM INJECTORS *') - content.append(' ********************/\n') - - for key, value in injectors['custom'].items(): - content.append('// ' + key) - content.append(value['code'] + '\n') - - write_file(fidl_runtime_path, content) - -def gen_custom_injector(insert, f_class, f_mode, custom_injector): - # format the custom injector lines - custom_injector = ' ' + custom_injector # add spaces before the first line - custom_injector = custom_injector.rstrip('\n') # remove last \n character - custom_injector = custom_injector.replace('\n', '\n ') # add spaces after every \n character - - # read template - lines = read_file(injector_template) - - # modify template - lines[0] = 'class _%s_%sFInjector : public SoftwareFaultInjector {' % (f_class, f_mode) - lines[4] = custom_injector - lines.append('static RegisterFaultInjector %s new _%s_%sFInjector());' % (insert, f_class, f_mode)) - - return lines - -################################################################################ - -# modify llvm_pass/CMakeLists.txt and remove the selector file -def del_selectors(selectorfilenames): - l = read_file(cmakelists) - for n in selectorfilenames: - try: - l.remove(' software_failures/%s' % n) - except Exception: - pass - try: - os.remove(os.path.join(software_failures_passes_dir, n)) - except Exception: - pass - write_file(cmakelists, l) - -def list_injectors(injector_type): - all_injectors = read_input_yaml(all_injectors_yaml) - injectors = all_injectors[injector_type] - - if injectors == {}: - print('No injector exists!') - else: - for key in injectors: - print(key) - -def add_injectors(filename, injector_type): - inp = read_input_yaml(filename) - - all_injectors = read_input_yaml(all_injectors_yaml) - injectors = all_injectors[injector_type] - - if isinstance(inp, list): - for n in inp: - options = parse_input(n) # parses yaml - print('Generating %s(%s) %s software injector.' % (options['f_mode'], options['f_class'], injector_type)) - gen_runtime_code(options, injectors) # generates runtime fault injector code and insert runtime into a dictionary - gen_and_write_selector(options) # generates selector file - else: - options = parse_input(inp) # parses yaml - print('Generating %s(%s) %s software injector.' % (options['f_mode'], options['f_class'], injector_type)) - gen_runtime_code(options, injectors) # generates runtime fault injector code and insert runtime into a dictionary - gen_and_write_selector(options) # generates selector file - - write_yaml(all_injectors, all_injectors_yaml) # writes runtime into a yaml/storage file - - gen_runtime_file(all_injectors) # generate the actual .cpp runtime file - -def del_injector(name, injector_type): - all_injectors = read_input_yaml(all_injectors_yaml) - injectors = all_injectors[injector_type] - - if name in injectors: - del_selectors([injectors[name]['selectorfilename']]) - del injectors[name] - else: - print('Error: %s is not a %s injector!' % (name, injector_type), file = sys.stderr) - exit(1) - - write_yaml(all_injectors, all_injectors_yaml) - - gen_runtime_file(all_injectors) - -def del_injectors(injector_type): - all_injectors = read_input_yaml(all_injectors_yaml) - injectors = all_injectors[injector_type] - - selectorfilenames = [] - for _, value in injectors.items(): - selectorfilenames.append(value['selectorfilename']) - del_selectors(selectorfilenames) - - all_injectors[injector_type] = {} - write_yaml(all_injectors, all_injectors_yaml) - - gen_runtime_file(all_injectors) - -################################################################################ - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = 'Error: ' + msg - print(msg, file = sys.stderr) - print(__doc__ % globals(), file = sys.stderr) - sys.exit(retval) - -def parse_args(args): - option = None - name = None - - if len(args) == 0: - usage() - elif len(args) == 1: - option = args[0] - elif len(args) == 2: - option = args[0] - name = args[1] - else: - usage('Invalid Input!') - - return (option, name) - -def main(args): - option, name = parse_args(args) - if option == '-r': - if name == 'custom' or name == 'default': - print('Deleting all %s injectors!' % (name)) - del_injectors(name) - else: - print('Deleting %s' % (name)) - del_injector(name, 'custom') - - elif option == '-l': - if name == 'custom' or name == 'default': - print('Current %s FIDL software fault injectors:' % (name)) - list_injectors(name) - else: - usage('%s is not a valid injector type!' % (name)) - - elif option == '-a': - if name == 'default': - add_injectors(default_failures_yaml, name) - else: - add_injectors(name, 'custom') - print('Injector module(s) created.') - elif option == '-h': - usage() - else: - usage('%s is not a valid option!' % (option)) - -if __name__ == '__main__': - main(sys.argv[1:]) - diff --git a/tools/FIDL/config/NewInjectorTemplate.cpp b/tools/FIDL/config/NewInjectorTemplate.cpp deleted file mode 100644 index d9a253ab..00000000 --- a/tools/FIDL/config/NewInjectorTemplate.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// DO NOT MODIFY! - public: - virtual void injectFault(long llfi_index, unsigned size, unsigned fi_bit, char *buf) { - // Write your code here! - - // END - } -}; - diff --git a/tools/FIDL/config/TargetAllTemplate.cpp b/tools/FIDL/config/TargetAllTemplate.cpp deleted file mode 100644 index 3db3bb79..00000000 --- a/tools/FIDL/config/TargetAllTemplate.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// DO NOT MODIFY! -// File generated on - -// This file was generated from /tools/FIDL/TargetAllTemplate.cpp -// by the /tools/FIDL/FIDL-Algorithm.py -// See https://github.com/DependableSystemsLab/LLFI/wiki/Using-FIDL-to-create-a-Custom-Software-Fault-Injector-and-a-Custom-Instruction-Selector -// for more information. - -#include "llvm/Pass.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Support/CFG.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" - -#include "Utils.h" -#include "FIInstSelector.h" -#include "FICustomSelectorManager.h" -#include "_SoftwareFaultRegSelectors.h" - -#include -#include -#include -#include -#include - -using namespace llvm; -namespace llfi { -//fidl_1 - public: - virtual void getCompileTimeInfo(std::map& info) { -//fidl_2 - } - private: - virtual bool isInstFITarget(Instruction* inst) { -//fidl_3 - } - - static bool isTargetLLFIIndex(Instruction* inst) { -//fidl_4 - if (n > 0) { - long llfiindex = getLLFIIndexofInst(inst); - for (int i = 0; i < n; i++) { - if (llfiindex == targeted_indices[i]) { - return true; - } - } - return false; - } else { - return true; - } - } -}; - diff --git a/tools/FIDL/config/TargetMultiSourceTemplate.cpp b/tools/FIDL/config/TargetMultiSourceTemplate.cpp deleted file mode 100644 index 8a5cf51b..00000000 --- a/tools/FIDL/config/TargetMultiSourceTemplate.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// DO NOT MODIFY! -// File generated on - -// This file was generated from /tools/FIDL/TargetMultiSourceTemplate.cpp -// by the /tools/FIDL/FIDL-Algorithm.py -// See https://github.com/DependableSystemsLab/LLFI/wiki/Using-FIDL-to-create-a-Custom-Software-Fault-Injector-and-a-Custom-Instruction-Selector -// for more information. - -#include "llvm/Pass.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Support/CFG.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" - -#include "Utils.h" -#include "FIInstSelector.h" -#include "FICustomSelectorManager.h" -#include "_SoftwareFaultRegSelectors.h" - -#include -#include -#include -#include -#include - -using namespace llvm; -namespace llfi { -//fidl_1 - public: -//fidl_2 - if (funcNamesTargetArgs.size() == 0) { -//fidl_3 - } - } - - virtual void getCompileTimeInfo(std::map& info) { -//fidl_4 - for (std::map >::iterator MI = funcNamesTargetArgs.begin(); MI != funcNamesTargetArgs.end(); MI++) { - info["targets"] += MI->first + "()/"; - } - //remove the '/' at the end - info["targets"] = info["targets"].substr(0, info["targets"].length() - 1); -//fidl_5 - } - - static bool isTarget(CallInst* CI, Value* T) { - std::string func_name = CI->getCalledFunction()->getName(); -//fidl_6 - if (*SI >= CI->getNumArgOperands()) { - continue; - } else if (T == CI->getArgOperand(*SI)) { - return true; - } - } - return false; - } - - private: - static std::map > funcNamesTargetArgs; - - virtual bool isInstFITarget(Instruction* inst) { - if (!isa(inst)) { - return false; - } - - CallInst* CI = dyn_cast(inst); - Function* called_func = CI->getCalledFunction(); - if (called_func == NULL) { - return false; - } - - std::string func_name = std::string(called_func->getName()); -//fidl_7 - } - - static bool isTargetLLFIIndex(Instruction* inst) { -//fidl_8 - if (n > 0) { - long llfiindex = getLLFIIndexofInst(inst); - for (int i = 0; i < n; i++) { - if (llfiindex == targeted_indices[i]) { - return true; - } - } - return false; - } else { - return true; - } - } - - // does something in funcNamesTargetArgs matches partially with func_name? - static std::map >::iterator key_partially_matches(std::string func_name) { - std::map >::iterator SI; - for (SI = funcNamesTargetArgs.begin(); SI != funcNamesTargetArgs.end(); SI++) { - if (func_name.find(SI->first) != std::string::npos) { - break; - } - } - return SI; - } -}; - -//fidl_9 - private: - virtual bool isRegofInstFITarget(Value *reg, Instruction *inst) { - if (isa(inst) == false) { - return false; - } - CallInst* CI = dyn_cast(inst); - Function* called_func = CI->getCalledFunction(); - if (called_func == NULL) { - return false; - } -//fidl_10 - } else { - return false; - } - } -}; - diff --git a/tools/FIDL/config/TargetSingleTemplate.cpp b/tools/FIDL/config/TargetSingleTemplate.cpp deleted file mode 100644 index cc6292ea..00000000 --- a/tools/FIDL/config/TargetSingleTemplate.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// DO NOT MODIFY! -// File generated on - -// This file was generated from /tools/FIDL/TargetSingleTemplate.cpp -// by the /tools/FIDL/FIDL-Algorithm.py -// See https://github.com/DependableSystemsLab/LLFI/wiki/Using-FIDL-to-create-a-Custom-Software-Fault-Injector-and-a-Custom-Instruction-Selector -// for more information. - -#include "llvm/Pass.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Support/CFG.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" - -#include "Utils.h" -#include "FIInstSelector.h" -#include "FICustomSelectorManager.h" -#include "_SoftwareFaultRegSelectors.h" - -#include -#include -#include -#include -#include - -using namespace llvm; -namespace llfi { -//fidl_1 - public: -//fidl_2 - if (funcNames.size() == 0) { -//fidl_3 - } - } - - virtual void getCompileTimeInfo(std::map& info) { -//fidl_4 - for(std::set::iterator SI = funcNames.begin(); SI != funcNames.end(); SI++){ - info["targets"] += *SI + "()/"; - } - //remove the '/' at the end - info["targets"] = info["targets"].substr(0, info["targets"].length() - 1); -//fidl_5 - } - - private: - static std::set funcNames; - - virtual bool isInstFITarget(Instruction* inst) { - if (!isa(inst)) { - return false; - } - - CallInst* CI = dyn_cast(inst); - Function* called_func = CI->getCalledFunction(); - if (called_func == NULL) { - return false; - } - - std::string func_name = std::string(called_func->getName()); -//fidl_6 - } - - static bool isTargetLLFIIndex(Instruction* inst) { -//fidl_7 - if (n > 0) { - long llfiindex = getLLFIIndexofInst(inst); - for (int i = 0; i < n; i++) { - if (llfiindex == targeted_indices[i]) { - return true; - } - } - return false; - } else { - return true; - } - } - - static std::set::iterator key_partially_matches(std::string func_name) { - std::set::iterator SI; - for (SI = funcNames.begin(); SI != funcNames.end(); SI++) { - if (func_name.find(*SI) != std::string::npos) { - break; - } - } - return SI; - } -}; - diff --git a/tools/FIDL/config/default_failures.yaml b/tools/FIDL/config/default_failures.yaml deleted file mode 100644 index 6246d7cc..00000000 --- a/tools/FIDL/config/default_failures.yaml +++ /dev/null @@ -1,457 +0,0 @@ -- Failure_Class: API - Failure_Mode: BufferOverflow - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [2] - fwrite: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - -- Failure_Class: API - Failure_Mode: BufferUnderflow - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [2] - fwrite: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: -40 - -- Failure_Class: API - Failure_Mode: InappropriateClose - - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - dst - Action: - Perturb: InappropriateCloseInjector - option: True - -- Failure_Class: API - Failure_Mode: IncorrectOutput - - New_Failure_Mode: - Trigger: - return - Action: - Corrupt - -- Failure_Class: API - Failure_Mode: NoClose - - New_Failure_Mode: - Trigger: - call: [fclose] - Target: - src: - fclose: [0] - Action: - Perturb: InappropriateCloseInjector - option: False - -- Failure_Class: API - Failure_Mode: NoOpen - - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - src: - fopen: [0] - Action: - Corrupt - -- Failure_Class: API - Failure_Mode: NoOutput - - New_Failure_Mode: - Trigger: - return - Action: - Freeze - -- Failure_Class: API - Failure_Mode: WrongAPI - - New_Failure_Mode: - Trigger: - call: [fread, fwrite, fgetc, fopen] - Target: - src: - fread: [3] - fwrite: [3] - fgetc: [2] - fopen: [0, 1] - Action: - Corrupt - -- Failure_Class: API - Failure_Mode: WrongMode - - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - src: - fopen: [1] - Action: - Corrupt - -- Failure_Class: Data - Failure_Mode: BufferOverflowMalloc - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - src: - malloc: [0] - calloc: [0] - Action: - Perturb: ChangeValueInjector - option: False - value: -40 - -- Failure_Class: Data - Failure_Mode: BufferOverflowMemmove - - New_Failure_Mode: - Trigger: - call*: [memcpy, memmove] - Target: - src: - memcpy: [2] - memmove: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - -- Failure_Class: Data - Failure_Mode: DataCorruption - - New_Failure_Mode: - Trigger: - call: - [all] - Target: - src: - all: [0] - Action: - Corrupt - -- Failure_Class: Data - Failure_Mode: WrongDestination - - New_Failure_Mode: - Trigger: - call*: [memcpy, memmove, memcmp] - Target: - src: - memcpy: [0] - memmove: [0] - memcmp: [0] - Action: - Corrupt - -- Failure_Class: Data - Failure_Mode: WrongPointer - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [0] - fwrite: [0] - Action: - Corrupt - -- Failure_Class: Data - Failure_Mode: WrongSource - - New_Failure_Mode: - Trigger: - call*: [memcpy, memmove, memcmp] - Target: - src: - memcpy: [1] - memmove: [1] - memcmp: [1] - Action: - Corrupt - -- Failure_Class: IO - Failure_Mode: WrongRetrievedAddress - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [3] - fwrite: [0] - Action: - Corrupt - -- Failure_Class: IO - Failure_Mode: WrongRetrievedFormat - - New_Failure_Mode: - Trigger: - call: [fread] - Target: - src: - fread: [1] - Action: - Perturb: WrongFormatInjector - -- Failure_Class: IO - Failure_Mode: WrongSavedAddress - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [0] - fwrite: [3] - Action: - Corrupt - -- Failure_Class: IO - Failure_Mode: WrongSavedFormat - - New_Failure_Mode: - Trigger: - call: [fwrite] - Target: - src: - fwrite: [1] - Action: - Perturb: WrongFormatInjector - -- Failure_Class: MPI - Failure_Mode: DeadLock - - New_Failure_Mode: - Trigger: - call: [recv, send] - Target: - src: - recv: [0] - send: [0] - Action: - Corrupt - -- Failure_Class: MPI - Failure_Mode: InvalidMessage - - New_Failure_Mode: - Trigger: - call: [recv, send] - Target: - src: - recv: [1] - send: [1] - Action: - Perturb: ChangeValueInjector - option: False - value: 1024 - -- Failure_Class: MPI - Failure_Mode: InvalidSender - - New_Failure_Mode: - Trigger: - call: [connect, accept] - Target: - src: - connect: [1] - accept: [1] - Action: - Corrupt - -- Failure_Class: MPI - Failure_Mode: NoAck - - New_Failure_Mode: - Trigger: - call: [recv] - Target: - dst - Action: - Freeze - -- Failure_Class: MPI - Failure_Mode: NoDrain - - New_Failure_Mode: - Trigger: - call: [recv] - Target: - src: - recv: [3] - Action: - Perturb: ChangeValueInjector - option: True - value: 5000 - -- Failure_Class: MPI - Failure_Mode: NoMessage - - New_Failure_Mode: - Trigger: - call: [recv, send] - Target: - src: - recv: [1] - send: [1] - Action: - Freeze - -- Failure_Class: MPI - Failure_Mode: PacketStorm - - New_Failure_Mode: - Trigger: - call: [recv] - Target: - src: - recv: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: -40 - -- Failure_Class: Res - Failure_Mode: CPUHog - - New_Failure_Mode: - Trigger: - return - Action: - Delay - -- Failure_Class: Res - Failure_Mode: DeadLock - - New_Failure_Mode: - Trigger: - call: [pthread_join] - Target: - src: - pthread_join: [0] - Action: - Perturb: PthreadDeadLockInjector - -- Failure_Class: Res - Failure_Mode: InvalidPointer - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Corrupt - -- Failure_Class: Res - Failure_Mode: LowMemory - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Perturb: MemoryExhaustionInjector - option: False - -- Failure_Class: Res - Failure_Mode: MemoryExhaustion - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Perturb: MemoryExhaustionInjector - option: True - -- Failure_Class: Res - Failure_Mode: MemoryLeak - - New_Failure_Mode: - Trigger: - call: [free] - Target: - src: - free: [0] - Action: - Perturb: MemoryLeakInjector - -- Failure_Class: Res - Failure_Mode: StalePointer - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Perturb: StalePointerInjector - -- Failure_Class: Res - Failure_Mode: ThreadKiller - - New_Failure_Mode: - Trigger: - call: [pthread_join] - Target: - src: - pthread_join: [0] - Action: - Perturb: PthreadThreadKillerInjector - -- Failure_Class: Res - Failure_Mode: UnderAccumulator - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - src: - malloc: [0] - calloc: [0] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - -- Failure_Class: Timing - Failure_Mode: RaceCondition - - New_Failure_Mode: - Trigger: - call: [pthread_mutex_lock, pthread_mutex_trylock] - Target: - src: - pthread_mutex_lock: [0] - pthread_mutex_trylock: [0] - Action: - Perturb: PthreadRaceConditionInjector diff --git a/tools/FIDL/config/injectors.yaml b/tools/FIDL/config/injectors.yaml deleted file mode 100644 index 1d19d348..00000000 --- a/tools/FIDL/config/injectors.yaml +++ /dev/null @@ -1,3 +0,0 @@ -DO_NOT_MODIFY: DO_NOT_MODIFY -custom: {} -default: {} diff --git a/tools/FIDL/sample_scripts/fm1.yaml b/tools/FIDL/sample_scripts/fm1.yaml deleted file mode 100644 index 95bdfb6c..00000000 --- a/tools/FIDL/sample_scripts/fm1.yaml +++ /dev/null @@ -1,8 +0,0 @@ -Failure_Class: Class1 -Failure_Mode: FMode1 - -New_Failure_Mode: - Trigger: - call: [fopen, open] - Target: dst - Action: Corrupt diff --git a/tools/FIDL/sample_scripts/fm2.yaml b/tools/FIDL/sample_scripts/fm2.yaml deleted file mode 100644 index 5e1d5e61..00000000 --- a/tools/FIDL/sample_scripts/fm2.yaml +++ /dev/null @@ -1,17 +0,0 @@ -Failure_Class: Class2 -Failure_Mode: FMode2 - -New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Trigger*: [1, 50, 55, 60] - Target: - src: - fread: [2] - fwrite: [0] - Action: - Perturb: Custom_Injector - -Custom_Injector: | - int *Target = (int *) buf; - *Target = *Target + 1000; diff --git a/tools/FIDL/sample_scripts/fm3.yaml b/tools/FIDL/sample_scripts/fm3.yaml deleted file mode 100644 index 0c9a2058..00000000 --- a/tools/FIDL/sample_scripts/fm3.yaml +++ /dev/null @@ -1,13 +0,0 @@ -Failure_Class: Class3 -Failure_Mode: FMode3 - -New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [2] - fwrite: [0] - Action: - Perturb: InappropriateCloseInjector - option: False diff --git a/tools/FIDL/sample_scripts/fm4.yaml b/tools/FIDL/sample_scripts/fm4.yaml deleted file mode 100644 index 26193544..00000000 --- a/tools/FIDL/sample_scripts/fm4.yaml +++ /dev/null @@ -1,9 +0,0 @@ -Failure_Class: Class4 -Failure_Mode: FMode4 - -New_Failure_Mode: - Trigger: - return - Trigger*: [1, 50, 55, 60] - Action: - Corrupt diff --git a/tools/FIDL/tests/test_FIDL.py b/tools/FIDL/tests/test_FIDL.py deleted file mode 100755 index b43f0a59..00000000 --- a/tools/FIDL/tests/test_FIDL.py +++ /dev/null @@ -1,222 +0,0 @@ -#! /usr/bin/env python3 - -""" -This script is DEPRECATED! This was used to initially test if the 38/39 software failures -converted from .cpp format to the FIDL yaml format was working properly. - -Tests FIDL-Algorithm.py. - -Usage: %(prog)s [OPTIONS] - --a: adds FIDL test cases --r: removes FIDL test cases --t: executes tests --h: shows help - -Steps: -1. Execute with -a in LLFI_SRC -2. Build LLFI -3. Execute this script in LLFI_DST with -t -4. Execute -r in LLFI_SRC and rebuild LLFI to get rid of test cases - -Run this file *after* successfully executing the llfi regression test -This tests compares the generated FIDL-Algorithm.py's generated faultinjection.ll and profiling.ll -against the actual faultinjection.ll and profiling.ll from the real selector cpp file. -""" - -import sys, os, subprocess, shutil -import yaml -from distutils.dir_util import copy_tree - -script_path = os.path.realpath(os.path.dirname(__file__)) -prog = os.path.basename(sys.argv[0]) - -llfiroot = os.path.dirname(os.path.dirname(os.path.dirname(script_path))) - -test_config_path = os.path.join(script_path, 'test_config.yaml') -fidl_al_path = os.path.join(script_path, '../FIDL-Algorithm.py') - -programs_dir = os.path.join(llfiroot, 'test_suite/PROGRAMS/') -fidl_config_dir = os.path.join(script_path, 'fidl_config/') -fidl_tests_dir = os.path.join(script_path, 'fidl_test/') - -bin_path = os.path.join(script_path, '../../../bin') -instrument_path = os.path.join(bin_path, 'instrument') -profile_path = os.path.join(bin_path, 'profile') -injectfault_path = os.path.join(bin_path, 'injectfault') - -ir_ext = '.ll' -expected = 'Expected' -output = 'Output' - -def ir_file_equals(ir1_path, ir2_path): - with open(ir1_path) as f: - lines1 = f.read().splitlines() - with open(ir2_path) as f: - lines2 = f.read().splitlines() - - if len(lines1) != len(lines2): - return False - - # check the file content (ignoring the first line where the path - # of the files will be different - for i in range(1, len(lines1)): - if lines1[i] != lines2[i]: - return False - - return True - -def is_same_result(test_dir_path, program_name): - file_list = ['llfi/%s-faultinjection%s' % (program_name, ir_ext), 'llfi/%s-profiling%s' % (program_name, ir_ext)] - - # check if the generated ir files are equal - for n in file_list: - if not ir_file_equals(os.path.join(test_dir_path, expected, n), os.path.join(test_dir_path, output, n)): - return False - - return True - -def execute_tests(): - global doc - - # delete and create a new tests folder - del_mkdir(fidl_tests_dir) - - # create folder for each test case - for n in doc['tests']: - dir_name, name = extract_names(n) - dir_path = os.path.join(fidl_tests_dir, dir_name) - os.makedirs(dir_path) - print('Testing %s' % dir_name) - - program_name = n['config']['program'] - - l = [[expected, n['config']['simulate']], [output, name]] - for i in l: - # create inner directory and cd to it - inner_dir_path = os.path.join(dir_path, i[0]) - os.makedirs(inner_dir_path) - os.chdir(inner_dir_path) - - # creates the input.yaml - dump_yaml(os.path.join(inner_dir_path, 'input.yaml'), create_input_yaml(n, i[1])) - - # copy in the program - copy_tree(os.path.join(programs_dir, program_name), inner_dir_path) - - # instrument - execlist = [instrument_path, '--readable', '-lpthread', program_name + ir_ext] - ret_val = subprocess.call(execlist, stdout = open(os.devnull, 'wb'), stderr = open(os.devnull, 'wb')) - if (ret_val != 0): - print('Error: Instrument failed!') - exit(1) - - # check results - if is_same_result(dir_path, program_name): - print('Success: %s' % dir_name) - else: - print('Error: %s' % dir_name) - -def create_input_yaml(test, selector): - global doc - - template = doc['inputTemplate'].copy() - template['compileOption']['instSelMethod'][0]['customInstselector']['include'] = [selector] - - return template - -def del_mkdir(dir_path): - if os.path.exists(dir_path): - shutil.rmtree(dir_path) - os.makedirs(dir_path) - -def extract_names(test): - f_mode = test['FIDL']['Failure_Mode'] - f_class = test['FIDL']['Failure_Class'] - - filename = '_%s_%s' % (f_class, f_mode) - name = '%s(%s)' % (f_mode, f_class) - - return (filename, name) - -def dump_yaml(path, yaml_object): - f = open(path, 'w') - f.write(yaml.dump(yaml_object)) - f.close() - -def run_fidl_algorithm(add): - global doc - - # delete and create a new fidl script config(s) folder - if add: - del_mkdir(fidl_config_dir) - - for n in doc['tests']: - filename, name = extract_names(n) - filename = filename + '.yaml' - - # create new fidl script from ones specified in test_config.yaml - filename_path = os.path.join(fidl_config_dir, filename) - dump_yaml(filename_path, n['FIDL']) - - if add: - option = [fidl_al_path, '-a', filename_path] - else: - option = [fidl_al_path, '-r', name] - - #TODO what is the proper way for this? - # redirect error and such... - - # executes the fidl algorithm on the script - retVal = subprocess.call(option) - if retVal != 0: - print('Error: %s is not a valid fidl script!' % filename) - exit(1) - - # delete fidl script config(s) folder if removing - if os.path.exists(fidl_config_dir) and not add: - shutil.rmtree(fidl_config_dir) - -def remove_tests(): - run_fidl_algorithm(False) - -def add_tests(): - run_fidl_algorithm(True) - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = 'ERROR: ' + msg - print(msg, file = sys.stderr) - print(__doc__ % globals(), file = sys.stderr) - sys.exit(retval) - -def parse_args(args): - option = args[0] - if option == '-a': - add_tests() - elif option == '-r': - remove_tests() - elif option == '-t': - execute_tests() - elif option == '-h': - usage() - else: - usage('Invalid Argument: ' + option) - -def read_yaml(): - global doc - f = open(test_config_path) - doc = yaml.safe_load(f) - f.close() - -def main(args): - print('This script is DEPRECATED! See -h') - read_yaml() - parse_args(args) - -if __name__ == '__main__': - if len(sys.argv) == 1: - usage('Please specify an option.') - main(sys.argv[1:]) diff --git a/tools/FIDL/tests/test_config.yaml b/tools/FIDL/tests/test_config.yaml deleted file mode 100644 index 65f105fe..00000000 --- a/tools/FIDL/tests/test_config.yaml +++ /dev/null @@ -1,773 +0,0 @@ -#Deprecated -#This is the configuration yaml for test_FIDL.py - -tests: -- FIDL: #Tests Custom_Injector, and return value register targeting - Failure_Class: TestFIDL - Failure_Mode: custom - New_Failure_Mode: - Trigger: - return - Action: - Perturb: Custom_Injector - Custom_Injector: | - sleep(3); - config: - program: memcpy1 - simulate: CPUHog(Res) - -- FIDL: #Tests destination register targeting, and option for perturb - Failure_Class: TestFIDL - Failure_Mode: dst - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - dst - Action: - Perturb: InappropriateCloseInjector - option: True - config: - program: memcpy1 - simulate: InappropriateClose(API) - -- FIDL: #Tests multi-source register targeting - Failure_Class: TestFIDL - Failure_Mode: multisrc - New_Failure_Mode: - Trigger: - call: [fread, fwrite, fgetc, fopen] - Target: - src: - fread: [3] - fwrite: [3] - fgetc: [2] - fopen: [0, 1] - Action: - Corrupt - config: - program: memcpy1 - simulate: WrongAPI(API) - -- FIDL: #Tests multi-source register targeting - Failure_Class: TestFIDL - Failure_Mode: multisrc1 - New_Failure_Mode: - Trigger: - call*: [fread, fwrite, fgetc, fopen] - Target: - src: - fread: [3] - fwrite: [3] - fgetc: [2] - fopen: [0, 1] - Action: - Corrupt - config: - program: memcpy1 - simulate: WrongAPI(API) - -- FIDL: #Test single source register targeting, and option/value for perturb - Failure_Class: TestFIDL - Failure_Mode: singlesrc - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [2] - fwrite: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - config: - program: memcpy1 - simulate: BufferOverflow(API) - -- FIDL: - Failure_Class: API - Failure_Mode: BufferOverflow1 - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [2] - fwrite: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - config: - program: memcpy1 - simulate: BufferOverflow(API) - -- FIDL: - Failure_Class: API - Failure_Mode: BufferUnderflow1 - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [2] - fwrite: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: -40 - - config: - program: memcpy1 - simulate: BufferUnderflow(API) - -- FIDL: - Failure_Class: API - Failure_Mode: InappropriateClose1 - - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - dst - Action: - Perturb: InappropriateCloseInjector - option: True - - config: - program: memcpy1 - simulate: InappropriateClose(API) - -- FIDL: - Failure_Class: API - Failure_Mode: IncorrectOutput1 - - New_Failure_Mode: - Trigger: - return - Action: - Corrupt - - config: - program: mcf - simulate: IncorrectOutput(API) - -- FIDL: - Failure_Class: API - Failure_Mode: NoClose1 - - New_Failure_Mode: - Trigger: - call: [fclose] - Target: - src: - fclose: [0] - Action: - Perturb: InappropriateCloseInjector - option: False - - config: - program: memcpy1 - simulate: NoClose(API) - -- FIDL: - Failure_Class: API - Failure_Mode: NoOpen1 - - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - src: - fopen: [0] - Action: - Corrupt - - config: - program: memcpy1 - simulate: NoOpen(API) - -- FIDL: - Failure_Class: API - Failure_Mode: NoOutput1 - - New_Failure_Mode: - Trigger: - return - Action: - Freeze - - config: - program: sudoku2 - simulate: NoOutput(API) - -- FIDL: - Failure_Class: API - Failure_Mode: WrongAPI1 - - New_Failure_Mode: - Trigger: - call: [fread, fwrite, fgetc, fopen] - Target: - src: - fread: [3] - fwrite: [3] - fgetc: [2] - fopen: [0, 1] - Action: - Corrupt - - config: - program: mcf - simulate: WrongAPI(API) - -- FIDL: - Failure_Class: API - Failure_Mode: WrongMode1 - - New_Failure_Mode: - Trigger: - call: [fopen] - Target: - src: - fopen: [1] - Action: - Corrupt - - config: - program: memcpy1 - simulate: WrongMode(API) - -- FIDL: - Failure_Class: Data - Failure_Mode: BufferOverflowMalloc1 - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - src: - malloc: [0] - calloc: [0] - Action: - Perturb: ChangeValueInjector - option: False - value: -40 - - config: - program: memcpy1 - simulate: BufferOverflowMalloc(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: BufferOverflowMemmove1 - - New_Failure_Mode: - Trigger: - call*: [memcpy, memmove] - Target: - src: - memcpy: [2] - memmove: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - - config: - program: memcpy1 - simulate: BufferOverflowMemmove(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: DataCorruption1 - - New_Failure_Mode: - Trigger: - call: - [all] - Target: - src: - all: [0] - Action: - Corrupt - config: - program: mcf - simulate: DataCorruption(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: IncorrectOutput1 - - New_Failure_Mode: - Trigger: - return - Action: - Corrupt - - config: - program: mcf - simulate: IncorrectOutput(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: NoOutput1 - - New_Failure_Mode: - Trigger: - return - Action: - Freeze - config: - program: mcf - simulate: NoOutput(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: WrongDestination1 - - New_Failure_Mode: - Trigger: - call*: [memcpy, memmove, memcmp] - Target: - src: - memcpy: [0] - memmove: [0] - memcmp: [0] - Action: - Corrupt - - config: - program: memcpy1 - simulate: WrongDestination(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: WrongPointer1 - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [0] - fwrite: [0] - Action: - Corrupt - - config: - program: memcpy1 - simulate: WrongPointer(Data) - -- FIDL: - Failure_Class: Data - Failure_Mode: WrongSource1 - - New_Failure_Mode: - Trigger: - call*: [memcpy, memmove, memcmp] - Target: - src: - memcpy: [1] - memmove: [1] - memcmp: [1] - Action: - Corrupt - - config: - program: memcpy1 - simulate: WrongSource(Data) - -- FIDL: - Failure_Class: IO - Failure_Mode: WrongRetrievedAddress1 - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [3] - fwrite: [0] - Action: - Corrupt - - config: - program: memcpy1 - simulate: WrongRetrievedAddress(IO) - -- FIDL: - Failure_Class: IO - Failure_Mode: WrongRetrievedFormat1 - - New_Failure_Mode: - Trigger: - call: [fread] - Target: - src: - fread: [1] - Action: - Perturb: WrongFormatInjector - config: - program: memcpy1 - simulate: WrongRetrievedFormat(IO) - -- FIDL: - Failure_Class: IO - Failure_Mode: WrongSavedAddress1 - - New_Failure_Mode: - Trigger: - call: [fread, fwrite] - Target: - src: - fread: [0] - fwrite: [3] - Action: - Corrupt - - config: - program: memcpy1 - simulate: WrongSavedAddress(IO) - -- FIDL: - Failure_Class: IO - Failure_Mode: WrongSavedFormat1 - - New_Failure_Mode: - Trigger: - call: [fwrite] - Target: - src: - fwrite: [1] - Action: - Perturb: WrongFormatInjector - - config: - program: memcpy1 - simulate: WrongSavedFormat(IO) - -- FIDL: - Failure_Class: MPI - Failure_Mode: DeadLock1 - - New_Failure_Mode: - Trigger: - call: [recv, send] - Target: - src: - recv: [0] - send: [0] - Action: - Corrupt - - config: - program: mpi - simulate: DeadLock(MPI) - -- FIDL: - Failure_Class: MPI - Failure_Mode: InvalidMessage1 - - New_Failure_Mode: - Trigger: - call: [recv, send] - Target: - src: - recv: [1] - send: [1] - Action: - Perturb: ChangeValueInjector - option: False - value: 1024 - - config: - program: mpi - simulate: InvalidMessage(MPI) - -- FIDL: - Failure_Class: MPI - Failure_Mode: InvalidSender1 - - New_Failure_Mode: - Trigger: - call: [connect, accept] - Target: - src: - connect: [1] - accept: [1] - Action: - Corrupt - - config: - program: mpi - simulate: InvalidSender(MPI) - -- FIDL: - Failure_Class: MPI - Failure_Mode: NoAck1 - - New_Failure_Mode: - Trigger: - call: [recv] - Target: - dst - Action: - Freeze - - config: - program: mpi - simulate: NoAck(MPI) - -- FIDL: - Failure_Class: MPI - Failure_Mode: NoDrain1 - - New_Failure_Mode: - Trigger: - call: [recv] - Target: - src: - recv: [3] - Action: - Perturb: ChangeValueInjector - option: True - value: 5000 - - config: - program: mpi - simulate: NoDrain(MPI) - -- FIDL: - Failure_Class: MPI - Failure_Mode: NoMessage1 - - New_Failure_Mode: - Trigger: - call: [recv, send] - Target: - src: - recv: [1] - send: [1] - Action: - Freeze - - config: - program: mpi - simulate: NoMessage(MPI) - -- FIDL: - Failure_Class: MPI - Failure_Mode: PacketStorm1 - - New_Failure_Mode: - Trigger: - call: [recv] - Target: - src: - recv: [2] - Action: - Perturb: ChangeValueInjector - option: False - value: -40 - - config: - program: mpi - simulate: PacketStorm(MPI) - -- FIDL: - Failure_Class: Res - Failure_Mode: CPUHog1 - - New_Failure_Mode: - Trigger: - return - Action: - Delay - - config: - program: mcf - simulate: CPUHog(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: DeadLock1 - - New_Failure_Mode: - Trigger: - call: [pthread_join] - Target: - src: - pthread_join: [0] - Action: - Perturb: PthreadDeadLockInjector - - config: - program: deadlock - simulate: DeadLock(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: InvalidPointer1 - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Corrupt - - config: - program: mcf - simulate: InvalidPointer(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: LowMemory1 - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Perturb: MemoryExhaustionInjector - option: False - - config: - program: mcf - simulate: LowMemory(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: MemoryExhaustion1 - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Perturb: MemoryExhaustionInjector - option: True - - config: - program: memcpy1 - simulate: MemoryExhaustion(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: MemoryLeak1 - - New_Failure_Mode: - Trigger: - call: [free] - Target: - src: - free: [0] - Action: - Perturb: MemoryLeakInjector - - config: - program: mcf - simulate: MemoryLeak(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: StalePointer1 - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - dst - Action: - Perturb: StalePointerInjector - - config: - program: memcpy1 - simulate: StalePointer(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: ThreadKiller1 - - New_Failure_Mode: - Trigger: - call: [pthread_create, pthread_join] - Target: - src: - pthread_create: [0] - pthread_join: [0] - Action: - Perturb: PthreadThreadKillerInjector - - config: - program: deadlock - simulate: ThreadKiller(Res) - -- FIDL: - Failure_Class: Res - Failure_Mode: UnderAccumulator1 - - New_Failure_Mode: - Trigger: - call: [malloc, calloc] - Target: - src: - malloc: [0] - calloc: [0] - Action: - Perturb: ChangeValueInjector - option: False - value: 45 - - config: - program: memcpy1 - simulate: UnderAccumulator(Res) - -- FIDL: - Failure_Class: Timing - Failure_Mode: RaceCondition1 - - New_Failure_Mode: - Trigger: - call: [pthread_mutex_lock] - Target: - dst - Action: - Perturb: PthreadRaceConditionInjector - - config: - program: deadlock - simulate: RaceCondition(Timing) - -inputTemplate: - kernelOption: [forceRun] - compileOption: - instSelMethod: - - customInstselector: - include: [] - - regSelMethod: customregselector - customRegSelector: Automatic - - includeInjectionTrace: - - forward - - backward - - tracingPropagation: True - - tracingPropagationOption: - debugTrace: True/False - generateCDFG: True - - runOption: - - run: - numOfRuns: 1 - fi_type: AutoInjection diff --git a/tools/GenerateMakefile.py b/tools/GenerateMakefile.py index 18b49ca3..689b906a 100755 --- a/tools/GenerateMakefile.py +++ b/tools/GenerateMakefile.py @@ -19,8 +19,9 @@ """ import sys, os, subprocess, tempfile + script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) +sys.path.append(os.path.join(script_path, "../config")) import llvm_paths import re import glob @@ -32,170 +33,176 @@ clang = os.path.join(llvm_paths.LLVM_GXX_BIN_DIR, "clang") clangxx = os.path.join(llvm_paths.LLVM_GXX_BIN_DIR, "clang++") -fname = 'Makefile' -newline = '\n\n' -indent = '\n\t' +fname = "Makefile" +newline = "\n\n" +indent = "\n\t" filetypes = [".c", ".C", ".cpp", ".cxx", ".cp", ".CPP", ".cc"] options = { - "o": "a.out", - "sources": [], - "readable": False, - "debug": False, - "verbose": False, - "all": False, - "dir": "", - "flags": [], + "o": "a.out", + "sources": [], + "readable": False, + "debug": False, + "verbose": False, + "all": False, + "dir": "", + "flags": [], } -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) def verbosePrint(msg, verbose): - if verbose: - print(msg) + if verbose: + print(msg) def parseArgs(args): - global options - argid = 0 - while argid < len(args): - arg = args[argid] - if arg.startswith("-"): - if arg == "-o": - argid += 1 - options["o"] = args[argid] - elif arg == "--readable": - options["readable"] = True - elif arg == "--verbose": - options["verbose"] = True - elif arg == "--debug": - options["debug"] = True - elif arg == "--all": - options["all"] = True - elif arg == "--dir": + global options + argid = 0 + while argid < len(args): + arg = args[argid] + if arg.startswith("-"): + if arg == "-o": + argid += 1 + options["o"] = args[argid] + elif arg == "--readable": + options["readable"] = True + elif arg == "--verbose": + options["verbose"] = True + elif arg == "--debug": + options["debug"] = True + elif arg == "--all": + options["all"] = True + elif arg == "--dir": + argid += 1 + options["dir"] = args[argid] + elif arg == "--flags": + argid += 1 + while argid < len(args) and not (args[argid].startswith("-")): + options["flags"].append(args[argid]) + argid += 1 + argid -= 1 + elif arg == "--help" or arg == "-h": + usage() + else: + usage("Invalid argument: " + arg) + else: + options["sources"].append(arg) argid += 1 - options["dir"] = args[argid] - elif arg == "--flags": - argid += 1 - while argid < len(args) and not (args[argid].startswith('-')): - options["flags"].append(args[argid]) - argid += 1 - argid -= 1 - elif arg == "--help" or arg == "-h": - usage() - else: - usage("Invalid argument: " + arg) - else: - options["sources"].append(arg) - argid += 1 - if len(options["sources"]) == 0 and options["all"] == False: - usage("No input file(s) specified.") + if len(options["sources"]) == 0 and options["all"] == False: + usage("No input file(s) specified.") def selectCompiler(sourceFiles): - for inputfile in sourceFiles: - if inputfile.endswith(".cpp"): - return clangxx - return clang + for inputfile in sourceFiles: + if inputfile.endswith(".cpp"): + return clangxx + return clang -#Build the header for the Makefile +# Build the header for the Makefile def initializeMakefile(sourceFiles): - with open(fname, 'w') as fout: - fout.write('CC=' + selectCompiler(sourceFiles) + '\n') - fout.write('LINKER=' + llvmlink + '\n') - fout.write('OUTPUT=' + options["o"] + '\n') - - cflags = ['-w', '-emit-llvm', '-fno-use-cxa-atexit', '-g'] #Default compiler flags - lflags = ['-o', '$(OUTPUT)'] #Default linker flags - - if options['readable']: - cflags.append('-S') - lflags.append('-S') - else: - cflags.append('-c') + with open(fname, "w") as fout: + fout.write("CC=" + selectCompiler(sourceFiles) + "\n") + fout.write("LINKER=" + llvmlink + "\n") + fout.write("OUTPUT=" + options["o"] + "\n") + + cflags = [ + "-w", + "-emit-llvm", + "-fno-use-cxa-atexit", + "-g", + ] # Default compiler flags + lflags = ["-o", "$(OUTPUT)"] # Default linker flags + + if options["readable"]: + cflags.append("-S") + lflags.append("-S") + else: + cflags.append("-c") - if options['debug']: - cflags.append('-g') + if options["debug"]: + cflags.append("-g") - if options['flags']: - additionalFlags = ['-' + flag for flag in options['flags']] - cflags += additionalFlags + if options["flags"]: + additionalFlags = ["-" + flag for flag in options["flags"]] + cflags += additionalFlags - fout.write('CFLAGS=' + " ".join(cflags) + '\n') - fout.write('LINKER_FLAGS=' + " ".join(lflags) + '\n') + fout.write("CFLAGS=" + " ".join(cflags) + "\n") + fout.write("LINKER_FLAGS=" + " ".join(lflags) + "\n") -#Define the body of the Makefile +# Define the body of the Makefile def constructMakeFile(sourceFiles): - objList = [] - - if options['readable']: - fileextension = '.ll' - else: - fileextension = '.bc' - - with open(fname, 'a') as fout: - fout.write('SRCDIR_OBJS=') - - fileexts = ['\\' + filetype + '$' for filetype in filetypes] - regex = '|'.join(fileexts) - - for codeFile in sourceFiles: - objFile = re.sub(regex, fileextension, codeFile) - objList.append(objFile) - fout.write(objFile + ' ') - - fout.write(newline) - fout.write('build:') - - fout.write('$(SRCDIR_OBJS)') - - fout.write(indent) - fout.write('$(LINKER) ') - fout.write('$(LINKER_FLAGS) ') - fout.write('$(SRCDIR_OBJS)' + newline) - - for codeFile in sourceFiles: - objFile = re.sub(regex, fileextension, codeFile) - fout.write(objFile + ': ' + codeFile) - fout.write(indent) - fout.write('$(CC) $(CFLAGS) ' + codeFile) - fout.write(newline) - - fout.write('clean:') - fout.write(indent) - fout.write('rm -rf *.ll *.bc llfi* ') - fout.write(options["o"]) - fout.write(newline) + objList = [] + + if options["readable"]: + fileextension = ".ll" + else: + fileextension = ".bc" + + with open(fname, "a") as fout: + fout.write("SRCDIR_OBJS=") + + fileexts = ["\\" + filetype + "$" for filetype in filetypes] + regex = "|".join(fileexts) + + for codeFile in sourceFiles: + objFile = re.sub(regex, fileextension, codeFile) + objList.append(objFile) + fout.write(objFile + " ") + + fout.write(newline) + fout.write("build:") + + fout.write("$(SRCDIR_OBJS)") + + fout.write(indent) + fout.write("$(LINKER) ") + fout.write("$(LINKER_FLAGS) ") + fout.write("$(SRCDIR_OBJS)" + newline) + + for codeFile in sourceFiles: + objFile = re.sub(regex, fileextension, codeFile) + fout.write(objFile + ": " + codeFile) + fout.write(indent) + fout.write("$(CC) $(CFLAGS) " + codeFile) + fout.write(newline) + + fout.write("clean:") + fout.write(indent) + fout.write("rm -rf *.ll *.bc llfi* ") + fout.write(options["o"]) + fout.write(newline) def main(args): - parseArgs(args) + parseArgs(args) - if (options["dir"]): - os.chdir(options["dir"]) #Change working directory if dir is specified + if options["dir"]: + os.chdir(options["dir"]) # Change working directory if dir is specified - if (options["all"]): #Read all C/C++ files in the project directory - fileexts = ['*' + filetype for filetype in filetypes] - sourceFiles = [] - for files in fileexts: - sourceFiles.extend(glob.glob(files)) - else: - sourceFiles = options["sources"] + if options["all"]: # Read all C/C++ files in the project directory + fileexts = ["*" + filetype for filetype in filetypes] + sourceFiles = [] + for files in fileexts: + sourceFiles.extend(glob.glob(files)) + else: + sourceFiles = options["sources"] - initializeMakefile(sourceFiles) - constructMakeFile(sourceFiles) + initializeMakefile(sourceFiles) + constructMakeFile(sourceFiles) -if __name__=="__main__": - main(sys.argv[1:]) +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/tools/compiletoIR.py b/tools/compiletoIR.py index 80634982..5239289c 100755 --- a/tools/compiletoIR.py +++ b/tools/compiletoIR.py @@ -17,9 +17,13 @@ --help(-h): Show help information """ -import sys, os, subprocess, tempfile +import os +import subprocess +import sys +import tempfile + script_path = os.path.realpath(os.path.dirname(__file__)) -sys.path.append(os.path.join(script_path, '../config')) +sys.path.append(os.path.join(script_path, "../config")) import llvm_paths llvmlink = os.path.join(llvm_paths.LLVM_DST_ROOT, "bin/llvm-link") @@ -30,140 +34,144 @@ basedir = os.getcwd() options = { - "o": "a.out", - "sources": [], - "I": [], - "readable": False, - "debug": False, - "verbose": False, + "o": "a.out", + "sources": [], + "I": [], + "readable": False, + "debug": False, + "verbose": False, } -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) def verbosePrint(msg, verbose): - if verbose: - print(msg) + if verbose: + print(msg) def parseArgs(args): - global options - argid = 0 - while argid < len(args): - arg = args[argid] - if arg.startswith("-"): - if arg == "-o": - argid += 1 - options["o"] = os.path.join(basedir, args[argid]) - elif arg == "-I": + global options + argid = 0 + while argid < len(args): + arg = args[argid] + if arg.startswith("-"): + if arg == "-o": + argid += 1 + options["o"] = os.path.join(basedir, args[argid]) + elif arg == "-I": + argid += 1 + options["I"].append(os.path.join(basedir, args[argid])) + elif arg == "--readable": + options["readable"] = True + elif arg == "--verbose": + options["verbose"] = True + elif arg == "--debug": + options["debug"] = True + elif arg == "--help" or arg == "-h": + usage() + else: + usage("Invalid argument: " + arg) + else: + options["sources"].append(os.path.join(basedir, arg)) argid += 1 - options["I"].append(os.path.join(basedir, args[argid])) - elif arg == "--readable": - options["readable"] = True - elif arg == "--verbose": - options["verbose"] = True - elif arg == "--debug": - options["debug"] = True - elif arg == "--help" or arg == "-h": - usage() - else: - usage("Invalid argument: " + arg) - else: - options["sources"].append(os.path.join(basedir, arg)) - argid += 1 - if len(options["sources"]) == 0: - usage("No input file(s) specified.") + if len(options["sources"]) == 0: + usage("No input file(s) specified.") ################################################################################ def execute(execlist): - verbosePrint(' '.join(execlist), options["verbose"]) - p = subprocess.Popen(execlist) - p.wait() - return p.returncode + verbosePrint(" ".join(execlist), options["verbose"]) + p = subprocess.Popen(execlist) + p.wait() + return p.returncode def compileToIR(outputfile, inputfile): - if inputfile.endswith(".c"): - execlist = [llvmgcc] - else: - execlist = [llvmgxx] + if inputfile.endswith(".c"): + execlist = [llvmgcc] + else: + execlist = [llvmgxx] - execlist.extend(['-w', '-emit-llvm', '-o', outputfile, inputfile]) + execlist.extend(["-w", "-emit-llvm", "-o", outputfile, inputfile]) - for header_dir in options["I"]: - execlist.extend(['-I', header_dir]) + for header_dir in options["I"]: + execlist.extend(["-I", header_dir]) - if options['readable']: - execlist.append('-S') - else: - execlist.append('-c') + if options["readable"]: + execlist.append("-S") + else: + execlist.append("-c") - if options['debug']: - execlist.append('-g') + if options["debug"]: + execlist.append("-g") - return execute(execlist) + return execute(execlist) def linkFiles(outputfile, inputlist): - execlist = [llvmlink, '-o', outputfile] + execlist = [llvmlink, "-o", outputfile] + + if options["readable"]: + execlist.append("-S") - if options['readable']: - execlist.append('-S') + execlist.extend(inputlist) + return execute(execlist) - execlist.extend(inputlist) - return execute(execlist) ################################################################################ def compileProg(): - outputfile = options["o"] - srcfiles = options["sources"] - verbosePrint("Source files to be compiled: ", options["verbose"]) - verbosePrint(", ".join(srcfiles), options["verbose"]) - verbosePrint("\n======Compile======", options["verbose"]) - - if len(srcfiles) == 1: - retcode = compileToIR(outputfile, srcfiles[0]) - else: - tmpfiles = [] - for src in srcfiles: - file_handler, tmpfile = tempfile.mkstemp() - tmpfiles.append(tmpfile) - retcode = compileToIR(tmpfile, src) - if retcode != 0: - break - - if retcode == 0: - retcode = linkFiles(outputfile, tmpfiles) - - # cleaning up the temporary files - for tmpfile in tmpfiles: - try: - os.remove(tmpfile) - except: - pass - - if retcode != 0: - print("\nERROR: there was a compilation error, please follow"\ - " the provided instructions for %s or compile the "\ - "source file(s) to one single IR file manually." % prog, file=sys.stderr) - sys.exit(retcode) + outputfile = options["o"] + srcfiles = options["sources"] + verbosePrint("Source files to be compiled: ", options["verbose"]) + verbosePrint(", ".join(srcfiles), options["verbose"]) + verbosePrint("\n======Compile======", options["verbose"]) + + if len(srcfiles) == 1: + retcode = compileToIR(outputfile, srcfiles[0]) + else: + tmpfiles = [] + for src in srcfiles: + file_handler, tmpfile = tempfile.mkstemp() + tmpfiles.append(tmpfile) + retcode = compileToIR(tmpfile, src) + if retcode != 0: + break + + if retcode == 0: + retcode = linkFiles(outputfile, tmpfiles) + + # cleaning up the temporary files + for tmpfile in tmpfiles: + try: + os.remove(tmpfile) + except Exception: + pass + + if retcode != 0: + print( + "\nERROR: there was a compilation error, please follow" + " the provided instructions for %s or compile the " + "source file(s) to one single IR file manually." % prog, + file=sys.stderr, + ) + sys.exit(retcode) ################################################################################ def main(args): - parseArgs(args) - compileProg() + parseArgs(args) + compileProg() -if __name__=="__main__": - main(sys.argv[1:]) +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/tools/create_input_yaml.py b/tools/create_input_yaml.py index ad592f35..8976bd1e 100644 --- a/tools/create_input_yaml.py +++ b/tools/create_input_yaml.py @@ -34,7 +34,7 @@ fi_max_multiple: 1 """ -''' +""" 31069400218890051 ---> onnx.Constant 28552536314307922 ---> onnx.Reshape 1986948931 ---> onnx.Conv @@ -43,32 +43,74 @@ 31367363490312788 ---> onnx.Transpose 1835885895 ---> onnx.Gemm 33884119937478483 ---> onnx.Softmax -''' +""" + def getONNXId(name): - dict = {'Conv':1986948931, 'Relu':1970038098, 'MaxPool':30521821366870349, - 'MatMul':119251066446157, 'Add':6579265, 'AvgPool':30521821365761601, - 'Softmax':33884119937478483} + dict = { + "Conv": 1986948931, + "Relu": 1970038098, + "MaxPool": 30521821366870349, + "MatMul": 119251066446157, + "Add": 6579265, + "AvgPool": 30521821365761601, + "Softmax": 33884119937478483, + } + + return dict[name] - return dict[name]; # Inject faults into these layers only -whiteList = ['Conv', 'Relu', 'MaxPool', 'Add', 'MatMul', 'AvgPool'] +whiteList = ["Conv", "Relu", "MaxPool", "Add", "MatMul", "AvgPool"] + def main(): global content parser = argparse.ArgumentParser() - parser.add_argument("-l", "--layer", action="store", type=str, default="1", help="Layer for FI \t Default: 1") - parser.add_argument("-r", "--runs", action="store", type=str, default="500", help="Number of runs \t Default: 500") - parser.add_argument("-f", "--fault", action="store", type=str, default="1", help="FI_MAX_MULTIPLE \t Default: 1") - parser.add_argument("--random_layer", action="store_true", default=False, help="Do you want to randomly select a layer? \t Default: False") - parser.add_argument("--model_path", action="store", default="None", help="Path to Model.onnx file. It 's required in the case of random layer selection' \t Default: None") + parser.add_argument( + "-l", + "--layer", + action="store", + type=str, + default="1", + help="Layer for FI \t Default: 1", + ) + parser.add_argument( + "-r", + "--runs", + action="store", + type=str, + default="500", + help="Number of runs \t Default: 500", + ) + parser.add_argument( + "-f", + "--fault", + action="store", + type=str, + default="1", + help="FI_MAX_MULTIPLE \t Default: 1", + ) + parser.add_argument( + "--random_layer", + action="store_true", + default=False, + help="Do you want to randomly select a layer? \t Default: False", + ) + parser.add_argument( + "--model_path", + action="store", + default="None", + help="Path to Model.onnx file. It 's required in the case of random layer selection' \t Default: None", + ) args = parser.parse_args() - content = content.replace("numOfRuns: 500", "numOfRuns: "+str(args.runs)) - content = content.replace("fi_max_multiple: 1", "fi_max_multiple: "+str(args.fault)) + content = content.replace("numOfRuns: 500", "numOfRuns: " + str(args.runs)) + content = content.replace( + "fi_max_multiple: 1", "fi_max_multiple: " + str(args.fault) + ) if args.random_layer: global whiteList @@ -87,14 +129,17 @@ def main(): random_element_key = random.choice(list(layer_count)) random_element_value = random.randint(1, int(layer_count[random_element_key])) - content = content.replace("layerNo=1", "layerNo="+str(random_element_value)) - content = content.replace("layerId=1986948931", "layerId="+str(getONNXId(random_element_key))) - + content = content.replace("layerNo=1", "layerNo=" + str(random_element_value)) + content = content.replace( + "layerId=1986948931", "layerId=" + str(getONNXId(random_element_key)) + ) + else: - content = content.replace("layerNo=4", "layerNo="+str(args.layer)) + content = content.replace("layerNo=4", "layerNo=" + str(args.layer)) + + with open("input.yaml", "w") as f: + f.write(content) - f = open('input.yaml', 'w') - f.write(content) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/executeParallel.py b/tools/executeParallel.py index a64cb97c..adaf37a3 100644 --- a/tools/executeParallel.py +++ b/tools/executeParallel.py @@ -8,6 +8,7 @@ import subprocess import time + ##### MAIN ##### def main(): parser = argparse.ArgumentParser() @@ -56,7 +57,9 @@ def main(): with open("output.txt", "w") as f: print("Launching process for input " + str(i) + "...") # execute runAllInput.sh script in this directory. - p = subprocess.Popen(["bash", "./runAllInputs.sh", str(i)], stdout=f, stderr=f) + p = subprocess.Popen( + ["bash", "./runAllInputs.sh", str(i)], stdout=f, stderr=f + ) # Add directory to list of directories all_dirs.append(new_dir) @@ -71,5 +74,5 @@ def main(): time.sleep(5) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/fineTuneAndGetONNXFiles.py b/tools/fineTuneAndGetONNXFiles.py index 74330bb8..dcff1625 100644 --- a/tools/fineTuneAndGetONNXFiles.py +++ b/tools/fineTuneAndGetONNXFiles.py @@ -8,21 +8,21 @@ from transformers import DataCollatorForLanguageModeling models = [ -"aajrami/bert-sr-base", -"aajrami/bert-sr-medium", -"aajrami/bert-sr-small", -"aajrami/bert-mlm-base", -"aajrami/bert-mlm-medium", -"aajrami/bert-mlm-small", -"aajrami/bert-fc-base", -"aajrami/bert-fc-medium", -"aajrami/bert-fc-small", -"aajrami/bert-ascii-base", -"aajrami/bert-ascii-medium", -"aajrami/bert-ascii-small", -"aajrami/bert-rand-base", -"aajrami/bert-rand-medium", -"aajrami/bert-rand-small" + "aajrami/bert-sr-base", + "aajrami/bert-sr-medium", + "aajrami/bert-sr-small", + "aajrami/bert-mlm-base", + "aajrami/bert-mlm-medium", + "aajrami/bert-mlm-small", + "aajrami/bert-fc-base", + "aajrami/bert-fc-medium", + "aajrami/bert-fc-small", + "aajrami/bert-ascii-base", + "aajrami/bert-ascii-medium", + "aajrami/bert-ascii-small", + "aajrami/bert-rand-base", + "aajrami/bert-rand-medium", + "aajrami/bert-rand-small", ] inputs = [ @@ -30,11 +30,12 @@ "The is the largest organ in the human body", "One of the defining features of (the phylum that sea urchins belong to) is radial symmetry.", "Not my field, but let me offer one possible point of :", - "Inertial reference frames are, by definition, inertial. Rotation is a kind of ." + "Inertial reference frames are, by definition, inertial. Rotation is a kind of .", ] block_size = 128 + def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} @@ -51,9 +52,11 @@ def group_texts(examples): result["labels"] = result["input_ids"].copy() return result + def preprocess_function(examples): return tokenizer([" ".join(x) for x in examples["answers.text"]]) + eli5 = load_dataset("eli5", split="train_asks[:5000]") eli5 = eli5.train_test_split(test_size=0.2) eli5 = eli5.flatten() @@ -72,7 +75,9 @@ def preprocess_function(examples): lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4) tokenizer.pad_token = tokenizer.eos_token - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm_probability=0.15 + ) model = AutoModelForMaskedLM.from_pretrained(modelName).to("cuda") training_args = TrainingArguments( @@ -82,7 +87,7 @@ def preprocess_function(examples): per_device_train_batch_size=16, per_device_eval_batch_size=16, num_train_epochs=10, - weight_decay=0.01 + weight_decay=0.01, ) trainer = Trainer( @@ -96,25 +101,30 @@ def preprocess_function(examples): trainer.train() - modelName = modelName.replace('/', '-') + 'fine-tuned' + modelName = modelName.replace("/", "-") + "fine-tuned" tokenizer.save_pretrained(modelName) model = model.to("cpu") model.save_pretrained(modelName) # Convert the model to ONNX - os.system("python3 -m transformers.onnx --model=" + modelName + ' --feature=masked-lm ' + str(modelName + '-onnx')) + os.system( + "python3 -m transformers.onnx --model=" + + modelName + + " --feature=masked-lm " + + str(modelName + "-onnx") + ) os.system("rm -rf " + str(modelName)) - os.chdir(str(modelName + '-onnx')) + os.chdir(str(modelName + "-onnx")) tokenizer = AutoTokenizer.from_pretrained(".") # Convert inputs to .pb file for j in range(0, len(inputs)): inp = inputs[j] inp = tokenizer(inp, return_tensors="np") - #pdb.set_trace() - input_ids = numpy_helper.from_array(inp['input_ids']) - attention_mask = numpy_helper.from_array(inp['attention_mask']) + # pdb.set_trace() + input_ids = numpy_helper.from_array(inp["input_ids"]) + attention_mask = numpy_helper.from_array(inp["attention_mask"]) # Convert to torch tensors with open(os.path.join("", f"input{j}_0.pb"), "wb") as f: @@ -123,4 +133,4 @@ def preprocess_function(examples): with open(os.path.join("", f"input{j}_1.pb"), "wb") as f: f.write(attention_mask.SerializeToString()) - os.chdir('..') \ No newline at end of file + os.chdir("..") diff --git a/tools/gatherFIStats.py b/tools/gatherFIStats.py index 691aeb64..b90d0131 100644 --- a/tools/gatherFIStats.py +++ b/tools/gatherFIStats.py @@ -15,6 +15,7 @@ calculate_embedding_distance = False + class FIRun: # Model Input @@ -36,7 +37,18 @@ class FIRun: # FI layer name fi_layer_name = "" - def __init__(self, _model_input, _fi_run, _fi_bit, _fi_layer, _fi_old_val, _fi_new_val, _fi_cycle, _fi_opcode, _fi_layer_name): + def __init__( + self, + _model_input, + _fi_run, + _fi_bit, + _fi_layer, + _fi_old_val, + _fi_new_val, + _fi_cycle, + _fi_opcode, + _fi_layer_name, + ): self.fi_run = _fi_run self.fi_bit = _fi_bit self.fi_layer = _fi_layer @@ -106,13 +118,13 @@ def __init__(self, _model_name, _model_input_num, _fi_stats_dir, _fi_output_path self.parseFIStats() - if self.model_name == 'gpt2': + if self.model_name == "gpt2": self.parseGPT2Output() - elif 'bert' in self.model_name and self.model_name != 'roberta': + elif "bert" in self.model_name and self.model_name != "roberta": self.parseBertOutput() - elif 't5' in self.model_name: + elif "t5" in self.model_name: self.parseT5DecoderOutput() - elif self.model_name == 'roberta': + elif self.model_name == "roberta": self.parseRobertaOutput() else: print("Invalid model name") @@ -130,32 +142,48 @@ def __init__(self, _model_name, _model_input_num, _fi_stats_dir, _fi_output_path # Calculate Embedding distance for CodeBert def calculateEmbeddingDistanceCodeBert(self): from transformers import AutoTokenizer, AutoModel + tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base") model = AutoModel.from_pretrained("microsoft/codebert-base") - code_tokens=tokenizer.tokenize(str(self.model_correct_pred)) - tokens=[tokenizer.cls_token]+[tokenizer.sep_token]+code_tokens+[tokenizer.eos_token] - input_ids=tokenizer.convert_tokens_to_ids(tokens) - correct_embeddings = model(torch.tensor(input_ids)[None,:])[0] + code_tokens = tokenizer.tokenize(str(self.model_correct_pred)) + tokens = ( + [tokenizer.cls_token] + + [tokenizer.sep_token] + + code_tokens + + [tokenizer.eos_token] + ) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + correct_embeddings = model(torch.tensor(input_ids)[None, :])[0] for key in self.fi_outputs.keys(): - if self.fi_outputs[key] != self.model_correct_pred and key in self.fi_runs.keys(): + if ( + self.fi_outputs[key] != self.model_correct_pred + and key in self.fi_runs.keys() + ): # Calculate embedding distance - code_tokens=tokenizer.tokenize(str(self.fi_outputs[key])) - tokens=[tokenizer.cls_token]+[tokenizer.sep_token]+code_tokens+[tokenizer.eos_token] - input_ids=tokenizer.convert_tokens_to_ids(tokens) - incorrect_embeddings = model(torch.tensor(input_ids)[None,:])[0] + code_tokens = tokenizer.tokenize(str(self.fi_outputs[key])) + tokens = ( + [tokenizer.cls_token] + + [tokenizer.sep_token] + + code_tokens + + [tokenizer.eos_token] + ) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + incorrect_embeddings = model(torch.tensor(input_ids)[None, :])[0] set_trace() cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) - distance = cos(correct_embeddings[0][0], incorrect_embeddings[0][0]).item() + distance = cos( + correct_embeddings[0][0], incorrect_embeddings[0][0] + ).item() self.fi_outputs_embedding_distance[key] = distance # Calculate Embedding distance def calculateEmbeddingDistance(self): # codebert - #if self.model_name == 'codebert': + # if self.model_name == 'codebert': # g return self.calculateEmbeddingDistanceCodeBert() from flair.data import Sentence @@ -167,7 +195,10 @@ def calculateEmbeddingDistance(self): # Calculate embedding distance for key in self.fi_outputs.keys(): - if self.fi_outputs[key] != self.model_correct_pred and key in self.fi_runs.keys(): + if ( + self.fi_outputs[key] != self.model_correct_pred + and key in self.fi_runs.keys() + ): # Calculate embedding distance sentence1 = Sentence(self.model_correct_pred) sentence2 = Sentence(self.fi_outputs[key]) @@ -183,11 +214,13 @@ def calculateEmbeddingDistance(self): def identifyLang(self, text): # Check if model file exists, otherwise download it. - if not os.path.isfile('lid.176.ftz'): + if not os.path.isfile("lid.176.ftz"): print("Downloading fasttext model") - os.system("wget https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz") + os.system( + "wget https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz" + ) - model = fasttext.load_model('lid.176.ftz') + model = fasttext.load_model("lid.176.ftz") predictions = model.predict(text, k=1) return predictions[0][0].split("__label__")[1] @@ -195,10 +228,11 @@ def identifyLang(self, text): def mean_confidence_interval(self, data, confidence=0.95): import numpy as np import scipy.stats + a = 1.0 * np.array(data) n = len(a) m, se = np.mean(a), scipy.stats.sem(a) - h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1) + h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1) return m, h # Analyze the effect of faults on the model @@ -211,15 +245,23 @@ def analyzeFaults(self): # Assuming that fi_runs and fi_outputs share the smae keys. if len(self.fi_runs.keys()) != len(self.fi_outputs.keys()): print("----------------------------------") - print("WARNING: Mismatch between number of FI runs and number of FI outputs") + print( + "WARNING: Mismatch between number of FI runs and number of FI outputs" + ) k = [i for i in self.fi_outputs.keys() if i not in self.fi_runs.keys()] print("\t Keys not found: " + str(k)) - print("\t Prediction corresponding to these keys: " + str([self.fi_outputs[i] for i in k])) + print( + "\t Prediction corresponding to these keys: " + + str([self.fi_outputs[i] for i in k]) + ) print("\t Correct prediction: " + str(self.model_correct_pred)) print("----------------------------------") for key in self.fi_outputs.keys(): - if self.fi_outputs[key] != self.model_correct_pred and key in self.fi_runs.keys(): + if ( + self.fi_outputs[key] != self.model_correct_pred + and key in self.fi_runs.keys() + ): all_fi_bits.append(self.fi_runs[key].fi_bit) all_fi_layers.append(self.fi_runs[key].fi_layer) all_fi_opcode.append(self.fi_runs[key].fi_opcode) @@ -235,16 +277,25 @@ def analyzeFaults(self): # Calculate SDC rate self.sdc_rate = len(self.all_sdc) / self.fi_total_runs - self.confidence_scores_95 = 1.96 * math.sqrt(self.sdc_rate * (1 - self.sdc_rate) / self.fi_total_runs) - self.confidence_scores_99 = 2.58 * math.sqrt(self.sdc_rate * (1 - self.sdc_rate) / self.fi_total_runs) + self.confidence_scores_95 = 1.96 * math.sqrt( + self.sdc_rate * (1 - self.sdc_rate) / self.fi_total_runs + ) + self.confidence_scores_99 = 2.58 * math.sqrt( + self.sdc_rate * (1 - self.sdc_rate) / self.fi_total_runs + ) # Calculate Average Cosine Similarity and 95% confidence scores - distances = [self.fi_outputs_embedding_distance[key] for key in self.fi_outputs_embedding_distance.keys()] - self.avg_cosine_similarity, self.CI_cosine_similarity = self.mean_confidence_interval(distances, confidence=0.95) + distances = [ + self.fi_outputs_embedding_distance[key] + for key in self.fi_outputs_embedding_distance.keys() + ] + self.avg_cosine_similarity, self.CI_cosine_similarity = ( + self.mean_confidence_interval(distances, confidence=0.95) + ) # Parse BioBert output def parseBertOutput(self): - assert 'bert' in self.model_name + assert "bert" in self.model_name # Check if PredResult.txt exists if not os.path.isfile(self.fi_output_dir + "/onnx-pred/onnx-pred.txt"): @@ -257,19 +308,19 @@ def parseBertOutput(self): content = content.split("Run")[1:] # Make sure that the number of FI stats is same as the number of predictions. - #if len(content) != self.fi_total_runs and "Mismatch between number of FI stats and number of predictions": + # if len(content) != self.fi_total_runs and "Mismatch between number of FI stats and number of predictions": # set_trace() all_predictions = [] for run in content: try: - run_num = int(run.split(': ')[-1].split(' ', 1)[0]) - run_pred = str(run.split('\n')[1]) + run_num = int(run.split(": ")[-1].split(" ", 1)[0]) + run_pred = str(run.split("\n")[1]) all_predictions.append(run_pred) # Store the prediction self.fi_outputs[run_num] = run_pred - except: + except Exception: set_trace() print("Error parsing run: ", run) continue @@ -280,7 +331,7 @@ def parseBertOutput(self): # Parse RoBerta output def parseRobertaOutput(self): - assert self.model_name == 'roberta' + assert self.model_name == "roberta" # Check if PredResult.txt exists if not os.path.isfile(self.fi_output_dir + "/PredResult.txt"): @@ -295,13 +346,13 @@ def parseRobertaOutput(self): all_predictions = [] for run in content: try: - run_num = int(run.split('#')[1].split(' ')[0]) - run_pred = str(run.split(':')[-1].replace(' ','')) + run_num = int(run.split("#")[1].split(" ")[0]) + run_pred = str(run.split(":")[-1].replace(" ", "")) all_predictions.append(run_pred) # Store the prediction self.fi_outputs[run_num] = run_pred - except: + except Exception: print("Error parsing run: ", run) continue @@ -311,7 +362,7 @@ def parseRobertaOutput(self): # Parse GPT2 output def parseGPT2Output(self): - assert self.model_name == 'gpt2' + assert self.model_name == "gpt2" # Check if PredResult.txt exists if not os.path.isfile(self.fi_output_dir + "/PredResult.txt"): @@ -327,13 +378,13 @@ def parseGPT2Output(self): for run in content: try: - run_num = int(run.split('#')[1].split(' ')[0]) - run_pred = str(run.replace(' ','').split('\n')[1].split(':')[0]) + run_num = int(run.split("#")[1].split(" ")[0]) + run_pred = str(run.replace(" ", "").split("\n")[1].split(":")[0]) all_predictions.append(run_pred) # Store the prediction self.fi_outputs[run_num] = run_pred - except: + except Exception: print("Error parsing run: ", run) continue @@ -343,7 +394,7 @@ def parseGPT2Output(self): # Parse T5 decoder output def parseT5DecoderOutput(self): - assert 't5' in self.model_name + assert "t5" in self.model_name # Check if PredResult.txt exists if not os.path.isfile(self.fi_output_dir + "/PredResult.txt"): @@ -359,13 +410,13 @@ def parseT5DecoderOutput(self): for run in content: try: - run_num = int(run.split('#')[1].split(' ')[0]) - run_pred = str(run.replace('\n', '').split(':')[-1]) + run_num = int(run.split("#")[1].split(" ")[0]) + run_pred = str(run.replace("\n", "").split(":")[-1]) all_predictions.append(run_pred) # Store the prediction self.fi_outputs[run_num] = run_pred - except: + except Exception: print("Error parsing run: ", run) continue @@ -379,14 +430,14 @@ def parseFIStats(self): for subdir, dirs, files in os.walk(self.fi_stats_dir): for file in files: # Check if the file is a FI stats file - if 'llfi' in file and 'swp' not in file: + if "llfi" in file and "swp" not in file: self.fi_total_runs = self.fi_total_runs + 1 # Parse the FI stats file - fi_num = int(str(file).split('-')[-1].split('.')[0]) + fi_num = int(str(file).split("-")[-1].split(".")[0]) # Open and parse file - with open(os.path.abspath(subdir) + "/" + file, 'r') as f: + with open(os.path.abspath(subdir) + "/" + file, "r") as f: content = f.readlines() - content = content[0].replace(' ', '').split(',') + content = content[0].replace(" ", "").split(",") fi_bit = -1 fi_layer = -1 fi_old_val = -1 @@ -396,25 +447,43 @@ def parseFIStats(self): fi_opcode = "" for fields in content: - first, second = fields.split('=') - if first == 'fi_cycle': + first, second = fields.split("=") + if first == "fi_cycle": fi_cycle = int(second) - elif first == 'fi_bit': + elif first == "fi_bit": fi_bit = int(second) - elif first == 'ml_layer_number': + elif first == "ml_layer_number": fi_layer = int(second) - elif first == 'oldHex': + elif first == "oldHex": fi_old_val = int(second, 16) - elif first == 'newHex': + elif first == "newHex": fi_new_val = int(second, 16) - elif first == 'opcode': + elif first == "opcode": fi_opcode = second - elif first == 'ml_layer_name': + elif first == "ml_layer_name": fi_layer_name = second # Validate inputs - assert fi_bit != -1 and fi_layer != -1 and fi_old_val != -1 and fi_new_val != -1 and fi_cycle != -1 and fi_layer_name != "" and fi_opcode != "" - runObj = FIRun(self.model_input_num, fi_num, fi_bit, fi_layer, fi_old_val, fi_new_val, fi_cycle, fi_opcode, fi_layer_name) + assert ( + fi_bit != -1 + and fi_layer != -1 + and fi_old_val != -1 + and fi_new_val != -1 + and fi_cycle != -1 + and fi_layer_name != "" + and fi_opcode != "" + ) + runObj = FIRun( + self.model_input_num, + fi_num, + fi_bit, + fi_layer, + fi_old_val, + fi_new_val, + fi_cycle, + fi_opcode, + fi_layer_name, + ) # Add to dictionary self.fi_runs[fi_num] = runObj @@ -425,32 +494,78 @@ def printStats(self, args): print("------- Input: " + str(self.model_input_num) + " -------") print("SDC Rate: " + str(len(self.all_sdc) / self.fi_total_runs)) print("SDC Rate Confidence Scores: 95%: " + str(self.confidence_scores_95)) - print("Total SDCs: " + str(len(self.all_sdc)) + " Total runs: " + str(self.fi_total_runs)) + print( + "Total SDCs: " + + str(len(self.all_sdc)) + + " Total runs: " + + str(self.fi_total_runs) + ) if args.calc_cs: print("Average Cosine Similarity: " + str(self.avg_cosine_similarity)) - print("Cosine Similarity Confidence Scores: 95%: " + str(self.CI_cosine_similarity)) + print( + "Cosine Similarity Confidence Scores: 95%: " + + str(self.CI_cosine_similarity) + ) if args.verbose: # Print all SDCs - print("All SDCs: " + str([[key, self.fi_outputs[key]] for key in self.all_sdc])) + print( + "All SDCs: " + + str([[key, self.fi_outputs[key]] for key in self.all_sdc]) + ) print("Correct prediction: " + str(self.model_correct_pred)) print("FI_Layers: " + str(self.count_fi_layers)) print("FI_Bits: " + str(self.count_fi_bits)) - print("Corrupted vals: " + str([[key, self.fi_runs[key].fi_new_val] for key in self.all_sdc])) - print("FI Layers: " + str([[key, self.fi_runs[key].fi_layer] for key in self.all_sdc])) + print( + "Corrupted vals: " + + str([[key, self.fi_runs[key].fi_new_val] for key in self.all_sdc]) + ) + print( + "FI Layers: " + + str([[key, self.fi_runs[key].fi_layer] for key in self.all_sdc]) + ) + ##### MAIN ##### def main(): parser = argparse.ArgumentParser() parser.add_argument("results", help="Path to results directory") - parser.add_argument("model_name", help="Name of the model. Can be one of: biobert, codebert, gpt2, t5-encoder, t5-decoder, roberta") - parser.add_argument("--calculate-embeddings", action="store_true", default=False, help="Calculate embedding distance between correct and incorrect predictions. \t Default: False") - parser.add_argument("--sdc-rates", action="store_true", default=False, help="Print the SDC rate \t Default: False") - parser.add_argument("--calc-cs", action="store_true", default=False, help="Print the Cosine Similarity rate \t Default: False") - parser.add_argument("--verbose", action="store_true", default=False, help="Print verbose output \t Default: False") + parser.add_argument( + "model_name", + help="Name of the model. Can be one of: biobert, codebert, gpt2, t5-encoder, t5-decoder, roberta", + ) + parser.add_argument( + "--calculate-embeddings", + action="store_true", + default=False, + help="Calculate embedding distance between correct and incorrect predictions. \t Default: False", + ) + parser.add_argument( + "--sdc-rates", + action="store_true", + default=False, + help="Print the SDC rate \t Default: False", + ) + parser.add_argument( + "--calc-cs", + action="store_true", + default=False, + help="Print the Cosine Similarity rate \t Default: False", + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Print verbose output \t Default: False", + ) # Export to CSV? - parser.add_argument("--export-csv", action="store_true", default=False, help="Export to CSV \t Default: False") + parser.add_argument( + "--export-csv", + action="store_true", + default=False, + help="Export to CSV \t Default: False", + ) args = parser.parse_args() @@ -462,23 +577,30 @@ def main(): print("Results directory does not exist") return - if args.model_name not in ["biobert", "codebert", "gpt2", "t5-encoder", "t5-decoder", "roberta"]: + if args.model_name not in [ + "biobert", + "codebert", + "gpt2", + "t5-encoder", + "t5-decoder", + "roberta", + ]: print("Invalid model name") return stats_for_inputs = {} # get sub-directory of results directory - for dir in [ f for f in os.scandir(args.results) if f.is_dir() ]: + for dir in [f for f in os.scandir(args.results) if f.is_dir()]: input_num = int(dir.name) fi_stats_dir = "" model_output_dir = "" # get sub-directory of input directory - for dirs1 in [ f for f in os.scandir(dir.path) if f.is_dir() ]: + for dirs1 in [f for f in os.scandir(dir.path) if f.is_dir()]: directory_name = str(dirs1.name) - if directory_name not in ['prediction', 'out']: - assert 'llfi' in directory_name + if directory_name not in ["prediction", "out"]: + assert "llfi" in directory_name # This is LLFI's FI stats directory fi_stats_dir = os.path.abspath(dirs1.path) else: @@ -489,7 +611,9 @@ def main(): assert os.path.isdir(fi_stats_dir) and os.path.isdir(model_output_dir) assert input_num >= 0 and input_num < 10 - NLPModelObj = NLPModel(args.model_name, input_num, fi_stats_dir, model_output_dir) + NLPModelObj = NLPModel( + args.model_name, input_num, fi_stats_dir, model_output_dir + ) stats_for_inputs[input_num] = NLPModelObj # Print stats @@ -500,13 +624,32 @@ def main(): if args.export_csv: # Export overall statistics - header = ["Input", "SDC Rate", "SDC Rate Confidence Scores", "Total SDCs", "Total runs", "Average Cosine Similarity", "Cosine Similarity Confidence Scores"] + header = [ + "Input", + "SDC Rate", + "SDC Rate Confidence Scores", + "Total SDCs", + "Total runs", + "Average Cosine Similarity", + "Cosine Similarity Confidence Scores", + ] rows = [] for i in sorted(stats_for_inputs.keys()): - rows.append([i, stats_for_inputs[i].sdc_rate, stats_for_inputs[i].confidence_scores_95, len(stats_for_inputs[i].all_sdc), stats_for_inputs[i].fi_total_runs, stats_for_inputs[i].avg_cosine_similarity, stats_for_inputs[i].CI_cosine_similarity]) + rows.append( + [ + i, + stats_for_inputs[i].sdc_rate, + stats_for_inputs[i].confidence_scores_95, + len(stats_for_inputs[i].all_sdc), + stats_for_inputs[i].fi_total_runs, + stats_for_inputs[i].avg_cosine_similarity, + stats_for_inputs[i].CI_cosine_similarity, + ] + ) import csv - with open(str(args.model_name) + '_results.csv', 'w') as f: + + with open(str(args.model_name) + "_results.csv", "w") as f: write = csv.writer(f) # Export Model Name and leave a row empty write.writerow([args.model_name]) @@ -515,7 +658,21 @@ def main(): write.writerows(rows) # Export SDCs for each input - header = ["S. No.", "SDC", "", "Correct Prediction", "", "fi_run", "fi_bit", "fi_layer", "fi_old_val", "fi_new_val", "fi_cycle", "fi_opcode", "fi_layer_name"] + header = [ + "S. No.", + "SDC", + "", + "Correct Prediction", + "", + "fi_run", + "fi_bit", + "fi_layer", + "fi_old_val", + "fi_new_val", + "fi_cycle", + "fi_opcode", + "fi_layer_name", + ] rows = [] for i in sorted(stats_for_inputs.keys()): model_name = args.model_name @@ -535,11 +692,27 @@ def main(): fi_opcode = stats_for_inputs[i].fi_runs[sdc].fi_opcode fi_layer_name = stats_for_inputs[i].fi_runs[sdc].fi_layer_name - rows.append([counter, wrong_pred, "", stats_for_inputs[i].model_correct_pred, "", fi_run, fi_bit, fi_layer, fi_old_val, fi_new_val, fi_cycle, fi_opcode, fi_layer_name]) + rows.append( + [ + counter, + wrong_pred, + "", + stats_for_inputs[i].model_correct_pred, + "", + fi_run, + fi_bit, + fi_layer, + fi_old_val, + fi_new_val, + fi_cycle, + fi_opcode, + fi_layer_name, + ] + ) rows.append([]) # Open CSV file - with open(str(args.model_name) +"_all_inputs_results.csv", 'w') as f: + with open(str(args.model_name) + "_all_inputs_results.csv", "w") as f: write = csv.writer(f) # Export Model Name and leave a row empty write.writerow([args.model_name]) @@ -549,5 +722,5 @@ def main(): write.writerow(row) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/outputONNXGraph.py b/tools/outputONNXGraph.py index 05bee446..69ebc4c4 100644 --- a/tools/outputONNXGraph.py +++ b/tools/outputONNXGraph.py @@ -1,28 +1,30 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import onnx import pydot import os -import pdb +import sys + +# ---------------------------------------------------------------------- -#---------------------------------------------------------------------- def get_tensor_shape(node): # returns the shape of the tensor given an ONNX node - return tuple( int(item.dim_value) for item in node.type.tensor_type.shape.dim ) + return tuple(int(item.dim_value) for item in node.type.tensor_type.shape.dim) + +# ---------------------------------------------------------------------- -#---------------------------------------------------------------------- -def makeDot(model, addIndex = False): +def makeDot(model, addIndex=False): ingraph = model.graph # see e.g. https://pythonhaven.wordpress.com/2009/12/09/generating_graphs_with_pydot/ - outgraph = pydot.Dot(graph_type='digraph') + outgraph = pydot.Dot(graph_type="digraph") - #---------- + # ---------- # Note that in the onnx model (at least when created # from pytorch) the computational boxes do not have names # but rather the connections between them @@ -32,23 +34,23 @@ def makeDot(model, addIndex = False): # (which defines the value) but can have multiple # inputs connected. We draw an edge from each of the # inputs to the single output - #---------- + # ---------- # this maps from an edge / netlist name to the node # which provides the output with this name nameToNodeOfOutput = {} - #---------- + # ---------- # find input nodes which have initializers # these are not real inputs but rather weights # learned during training - #---------- + # ---------- - initializerNames = set([ node.name for node in ingraph.initializer ]) + initializerNames = set([node.name for node in ingraph.initializer]) - #---------- + # ---------- # add boxes for the input nodes - #---------- + # ---------- for index, node in enumerate(ingraph.input): # note that (at least when generated from pytorch) # things like convolution matrix weights @@ -58,23 +60,23 @@ def makeDot(model, addIndex = False): # this is a weight node, skip it continue - labels = [ "input " + node.name, - get_tensor_shape(node) - ] + labels = ["input " + node.name, get_tensor_shape(node)] gn = pydot.Node( "in%d" % (index + 1), - label = "\n".join([ str(x) for x in labels ]), - shape = 'record', style = 'filled', - fillcolor = '#A2CECE') + label="\n".join([str(x) for x in labels]), + shape="record", + style="filled", + fillcolor="#A2CECE", + ) outgraph.add_node(gn) assert node.name not in nameToNodeOfOutput nameToNodeOfOutput[node.name] = gn - #---------- + # ---------- # add boxes for the output nodes - #---------- + # ---------- outputGraphNodes = [] @@ -83,80 +85,79 @@ def makeDot(model, addIndex = False): # things like convolution matrix weights # can be considered as inputs - labels = [ "output " + node.name, - get_tensor_shape(node) - ] - + labels = ["output " + node.name, get_tensor_shape(node)] gn = pydot.Node( "out%d" % (index + 1), - label = "\n".join([ str(x) for x in labels ]), - shape = 'record') + label="\n".join([str(x) for x in labels]), + shape="record", + ) outgraph.add_node(gn) outputGraphNodes.append(gn) - #---------- + # ---------- # add boxes for the computational nodes # and the corresponding edges - #---------- + # ---------- for index, node in enumerate(ingraph.node): # note that these nodes most of the time # do not have a name, i.e. node.name is the empty string - labels = [ node.op_type, str(node.input), str(node.output), - ] + labels = [ + node.op_type, + str(node.input), + str(node.output), + ] - #---------- + # ---------- # this should go into some kind of decorator - #---------- - if node.op_type in ('Conv', 'MaxPool'): + # ---------- + if node.op_type in ("Conv", "MaxPool"): # TODO: get number of filter banks for attr in node.attribute: # TODO: we should guarantee an ordering of the labels - if attr.name == 'kernel_shape': + if attr.name == "kernel_shape": shape = tuple(int(x) for x in attr.ints) labels.append("kernel size " + str(shape)) - elif attr.name == 'strides': + elif attr.name == "strides": shape = tuple(int(x) for x in attr.ints) - if shape != (1,1): + if shape != (1, 1): labels.append("strides " + str(shape)) - - elif node.op_type == 'Reshape': + elif node.op_type == "Reshape": for attr in node.attribute: # TODO: we should guarantee an ordering of the labels - if attr.name == 'shape': + if attr.name == "shape": shape = tuple(int(x) for x in attr.ints) labels.append("shape " + str(shape)) - - elif node.op_type == 'Dropout': + elif node.op_type == "Dropout": for attr in node.attribute: # TODO: we should guarantee an ordering of the labels - if attr.name == 'ratio': + if attr.name == "ratio": labels.append("p=" + str(attr.f)) - - #---------- + # ---------- if addIndex: # for debugging labels.append("(index = %d)" % index) - # create a graphviz node gn = pydot.Node( - "n%d" % (index + 1), - label = "\n".join([ str(x) for x in labels ]), - shape = 'record', style = 'filled') + "n%d" % (index + 1), + label="\n".join([str(x) for x in labels]), + shape="record", + style="filled", + ) outgraph.add_node(gn) # add outputs first @@ -176,11 +177,11 @@ def makeDot(model, addIndex = False): # get the pydot node we have to connect to inputNode = nameToNodeOfOutput[inputName] - outgraph.add_edge(pydot.Edge(src = inputNode, dst = gn)) + outgraph.add_edge(pydot.Edge(src=inputNode, dst=gn)) - #---------- + # ---------- # add edges of output nodes to their sources - #---------- + # ---------- # note that the output nodes do not have an input @@ -189,16 +190,15 @@ def makeDot(model, addIndex = False): # get the pydot node we have to connect to inputNode = nameToNodeOfOutput[node.name] - outgraph.add_edge(pydot.Edge(src = inputNode, dst = graphNode)) - + outgraph.add_edge(pydot.Edge(src=inputNode, dst=graphNode)) return outgraph -#---------------------------------------------------------------------- -if __name__ == '__main__': +# ---------------------------------------------------------------------- + +if __name__ == "__main__": - import sys ARGV = sys.argv[1:] assert len(ARGV) == 2, "usage: in.onnx output.{dot,pdf,...}" @@ -206,26 +206,27 @@ def makeDot(model, addIndex = False): inputFname, outputFname = ARGV if os.path.exists(outputFname): - print >> sys.stderr,"output file " + outputFname + " exists already, refusing to overwrite it" + print( + "output file " + outputFname + " exists already, refusing to overwrite it", + file=sys.stderr, + ) sys.exit(1) - # infer output format from suffix - outputFormat = outputFname.split('.')[-1].lower() + outputFormat = outputFname.split(".")[-1].lower() if inputFname.endswith(".gz"): import gzip - fin = gzip.GzipFile(inputFname) - else: - fin = open(inputFname) - model = onnx.load(inputFname) + with gzip.GzipFile(inputFname) as fin: + model = onnx.load(fin) + else: + model = onnx.load(inputFname) outgraph = makeDot(model) - #---------- + # ---------- # write the graph out - #---------- - - outgraph.write(outputFname, format = outputFormat) + # ---------- + outgraph.write(outputFname, format=outputFormat) diff --git a/tools/tracediff.py b/tools/tracediff.py index 7d89e9c6..2e611a5f 100755 --- a/tools/tracediff.py +++ b/tools/tracediff.py @@ -1,6 +1,6 @@ #! /usr/bin/env python3 -#traceDiff.py +# traceDiff.py # Author: Sam Coulter # This python script is part of the greater LLFI system. # This script will examine two tracing output files generated by running a program after @@ -17,68 +17,71 @@ prog = os.path.basename(sys.argv[0]) -def traceDiff(argv, output = 0): - #save stdout so we can redirect it without mangling other python scripts - oldSTDOut = sys.stdout - - # TODO: rewrite the command line argument of the script - if output != 0: - sys.stdout = open(output, "wb") - if (len(argv) != 3): - print("ERROR: running option: %(prog)s " % {'prog': prog}, file=sys.stderr) - exit(1) - - goldFile = open(argv[1], 'r') - goldTrace = goldFile.read() - goldFile.close() - - faultyFile = open(argv[2], 'r') - faultyTrace = faultyFile.read() - faultyFile.close() - - goldTraceLines = goldTrace.split("\n") - faultyTraceLines = faultyTrace.split("\n") - - #Examine Header of Trace File - header = faultyTraceLines[0].split(' ') - for i in range(0, len(header) - 1): - keyword = header[i] - if keyword == "#TraceStartInstNumber:": - #Remove traces from golden trace that happened before fault injection point - faultyTraceStartPoint = int(header[i+1]) - faultyTraceLines.pop(0) - for i in range (0,faultyTraceStartPoint-1): - goldTraceLines.pop(0) - - #record and report the fault injected line - goldInjectedLine = diffLine(goldTraceLines[0]) - faultInjectedLine = diffLine(faultyTraceLines[0]) - diffID = goldInjectedLine.ID - print("#FaultReport") - print("1 @", faultyTraceStartPoint) - print(goldInjectedLine.raw, "/", faultInjectedLine.Value) - - #remove the fault injected lines - goldTraceLines.pop(0) - faultyTraceLines.pop(0) - - for line in goldTraceLines: - if line == "": - goldTraceLines.remove(line) - - for line in faultyTraceLines: - if line == "": - faultyTraceLines.remove(line) - - lenGT = len(goldTraceLines) - 1 - lenFT = len(faultyTraceLines) - 1 - - i = 0 - - if (lenGT < 0 and lenFT < 0): - return 0 - - ''' + +def traceDiff(argv, output=0): + # save stdout so we can redirect it without mangling other python scripts + oldSTDOut = sys.stdout + + # TODO: rewrite the command line argument of the script + if output != 0: + sys.stdout = open(output, "wb") # intentional stdout redirect + if len(argv) != 3: + print( + "ERROR: running option: %(prog)s " + % {"prog": prog}, + file=sys.stderr, + ) + sys.exit(1) + + with open(argv[1], "r") as goldFile: + goldTrace = goldFile.read() + + with open(argv[2], "r") as faultyFile: + faultyTrace = faultyFile.read() + + goldTraceLines = goldTrace.split("\n") + faultyTraceLines = faultyTrace.split("\n") + + # Examine Header of Trace File + header = faultyTraceLines[0].split(" ") + for i in range(0, len(header) - 1): + keyword = header[i] + if keyword == "#TraceStartInstNumber:": + # Remove traces from golden trace that happened before fault injection point + faultyTraceStartPoint = int(header[i + 1]) + faultyTraceLines.pop(0) + for i in range(0, faultyTraceStartPoint - 1): + goldTraceLines.pop(0) + + # record and report the fault injected line + goldInjectedLine = diffLine(goldTraceLines[0]) + faultInjectedLine = diffLine(faultyTraceLines[0]) + diffID = goldInjectedLine.ID + print("#FaultReport") + print("1 @", faultyTraceStartPoint) + print(goldInjectedLine.raw, "/", faultInjectedLine.Value) + + # remove the fault injected lines + goldTraceLines.pop(0) + faultyTraceLines.pop(0) + + for line in goldTraceLines: + if line == "": + goldTraceLines.remove(line) + + for line in faultyTraceLines: + if line == "": + faultyTraceLines.remove(line) + + lenGT = len(goldTraceLines) - 1 + lenFT = len(faultyTraceLines) - 1 + + i = 0 + + if lenGT < 0 and lenFT < 0: + return 0 + + """ while (faultyTraceLines[lenFT-i] == goldTraceLines[lenGT-i]): postDiffID = diffLine(goldTraceLines[lenGT-i]).ID faultyTraceLines.pop(lenFT-i) @@ -86,19 +89,25 @@ def traceDiff(argv, output = 0): i = i + 1 if lenFT-i < 0 or lenGT-i < 0: break - ''' - - report = diffReport(goldTraceLines, faultyTraceLines, faultyTraceStartPoint, diffID) - if report != None: - report.printSummary() - - #restore stdout - sys.stdout = oldSTDOut - -if (__name__ == "__main__"): - if len(sys.argv) >= 2 and (sys.argv[1] == '-h' or sys.argv[1] == '--help'): - print(("%(prog)s compares the golden program trace and fault injection program trace and summarizes the differences\n\n" - "running option: %(prog)s " %{"prog": prog}), file=sys.stderr) - else: - traceDiff(sys.argv) - + """ + + report = diffReport(goldTraceLines, faultyTraceLines, faultyTraceStartPoint, diffID) + if report != None: + report.printSummary() + + # restore stdout + sys.stdout = oldSTDOut + + +if __name__ == "__main__": + if len(sys.argv) >= 2 and (sys.argv[1] == "-h" or sys.argv[1] == "--help"): + print( + ( + "%(prog)s compares the golden program trace and fault injection program trace and summarizes the differences\n\n" + "running option: %(prog)s " + % {"prog": prog} + ), + file=sys.stderr, + ) + else: + traceDiff(sys.argv) diff --git a/tools/traceontograph.py b/tools/traceontograph.py index 4564f341..e1377e23 100755 --- a/tools/traceontograph.py +++ b/tools/traceontograph.py @@ -1,11 +1,11 @@ #! /usr/bin/env python3 -#traceOntoGraph.py -#Author: Sam Coulter -#This script will take 1 trace Union file as input, and 1 llfi program dot graph -#it will apply the tracing information to the graph so that fault injected instructions -#are bordered in red, and fault affected instructions have a yellow fill -#Usage: +# traceOntoGraph.py +# Author: Sam Coulter +# This script will take 1 trace Union file as input, and 1 llfi program dot graph +# it will apply the tracing information to the graph so that fault injected instructions +# are bordered in red, and fault affected instructions have a yellow fill +# Usage: # ./traceOntoGraph.py myTraceReportFile myProgramGraph.dot > myNewGraph.dot import sys @@ -18,50 +18,75 @@ prog = os.path.basename(sys.argv[0]) + def traceOntoGraph(traceFile, graphFile, output=0): - #save stdout so we can redirect it without mangling other python scripts - oldSTDOut = sys.stdout - if output != 0: - sys.stdout = open(output, "wb") + # save stdout so we can redirect it without mangling other python scripts + oldSTDOut = sys.stdout + if output != 0: + sys.stdout = open(output, "wb") - faultReports = parseFaultReportsfromFile(traceFile) + faultReports = parseFaultReportsfromFile(traceFile) - graphF = open(graphFile, 'r') - graphLines = graphF.readlines() - graphF.close() + with open(graphFile, "r") as graphF: + graphLines = graphF.readlines() - for rep in faultReports: - affectedInsts = rep.getAffectedSet() - affectedEdges = rep.getAffectedEdgesSet() - for i in range(len(graphLines)): - if ("llfiID_" + str(rep.faultID) + " [shape") in graphLines[i]: - graphLines[i] = graphLines[i][:-3] - graphLines[i] = graphLines[i] + ", style=\"filled\", fillcolor=\""+FAULT_INJECTED_BORDER_COLOR +\ - "\"];\n" - for x in affectedInsts: - if ("llfiID_" + str(x) + " [shape") in graphLines[i]: - graphLines[i] = graphLines[i][:-3] - graphLines[i] = graphLines[i] + ", style=\"filled\", fillcolor=\""+AFFECTED_FILL_COLOR+\ - "\"];\n" - for (s, e) in affectedEdges: - summation = sum(1 for line in graphLines if ("llfiID_" + str(s) + " -> " + "llfiID_" in line) and "blue" not in line) - if summation == 2: - if ("llfiID_" + str(s) + " -> " + "llfiID_" + str(e)+";") in graphLines[i]: - graphLines[i] = graphLines[i][:-2] - graphLines[i] = graphLines[i] + " [color=\"red\"];\n" + for rep in faultReports: + affectedInsts = rep.getAffectedSet() + affectedEdges = rep.getAffectedEdgesSet() + for i in range(len(graphLines)): + if ("llfiID_" + str(rep.faultID) + " [shape") in graphLines[i]: + graphLines[i] = graphLines[i][:-3] + graphLines[i] = ( + graphLines[i] + + ', style="filled", fillcolor="' + + FAULT_INJECTED_BORDER_COLOR + + '"];\n' + ) + for x in affectedInsts: + if ("llfiID_" + str(x) + " [shape") in graphLines[i]: + graphLines[i] = graphLines[i][:-3] + graphLines[i] = ( + graphLines[i] + + ', style="filled", fillcolor="' + + AFFECTED_FILL_COLOR + + '"];\n' + ) + for s, e in affectedEdges: + summation = sum( + 1 + for line in graphLines + if ("llfiID_" + str(s) + " -> " + "llfiID_" in line) + and "blue" not in line + ) + if summation == 2: + if ( + "llfiID_" + str(s) + " -> " + "llfiID_" + str(e) + ";" + ) in graphLines[i]: + graphLines[i] = graphLines[i][:-2] + graphLines[i] = graphLines[i] + ' [color="red"];\n' + print("".join(graphLines)) - print(''.join(graphLines)) + # restore stdout + sys.stdout = oldSTDOut - #restore stdout - sys.stdout = oldSTDOut if __name__ == "__main__": - if len(sys.argv) >= 2 and (sys.argv[1] == '-h' or sys.argv[1] == '--help'): - print(("%(prog)s applies the program trace difference summary to program static control-data-flow graph to visualize it\n\n" - "running option: %(prog)s " %{"prog": prog}), file=sys.stderr) - elif len(sys.argv) >= 3: - traceOntoGraph(sys.argv[1], sys.argv[2]) - else: - print("Error: running option: %(prog)s " %{"prog": prog}, file=sys.stderr) - exit(1) \ No newline at end of file + if len(sys.argv) >= 2 and (sys.argv[1] == "-h" or sys.argv[1] == "--help"): + print( + ( + "%(prog)s applies the program trace difference summary to program static control-data-flow graph to visualize it\n\n" + "running option: %(prog)s " + % {"prog": prog} + ), + file=sys.stderr, + ) + elif len(sys.argv) >= 3: + traceOntoGraph(sys.argv[1], sys.argv[2]) + else: + print( + "Error: running option: %(prog)s " + % {"prog": prog}, + file=sys.stderr, + ) + sys.exit(1) diff --git a/tools/tracetodot.py b/tools/tracetodot.py index e5bc46d0..fdf3ed1d 100755 --- a/tools/tracetodot.py +++ b/tools/tracetodot.py @@ -8,7 +8,6 @@ # Output: Generate trace different report files and its .dot files to the folder trace_report_output - """ %(prog)s needs to be called in the folder that contains the llfi trace files (e.g. /llfi_stat_output) @@ -23,7 +22,6 @@ """ - import sys, os import subprocess import shlex @@ -31,140 +29,123 @@ prog = os.path.basename(sys.argv[0]) - def parseArgs(args): - argid = 0 - while argid < len(args): - arg = args[argid] - if arg.startswith("-"): - if arg == "--help" or arg == "-h": - usage() - else: - usage("Invalid argument: " + arg) - argid += 1 - - -def usage(msg = None): - retval = 0 - if msg is not None: - retval = 1 - msg = "ERROR: " + msg - print(msg, file=sys.stderr) - print(__doc__ % globals(), file=sys.stderr) - sys.exit(retval) + argid = 0 + while argid < len(args): + arg = args[argid] + if arg.startswith("-"): + if arg == "--help" or arg == "-h": + usage() + else: + usage("Invalid argument: " + arg) + argid += 1 + + +def usage(msg=None): + retval = 0 + if msg is not None: + retval = 1 + msg = "ERROR: " + msg + print(msg, file=sys.stderr) + print(__doc__ % globals(), file=sys.stderr) + sys.exit(retval) def findPath(): - global currentpath, scriptdir - - currentpath = os.getcwd() - #print (currentpath) + global currentpath, scriptdir - scriptdir = os.path.dirname(os.path.abspath(__file__)) + currentpath = os.getcwd() + # print (currentpath) + scriptdir = os.path.dirname(os.path.abspath(__file__)) def makeTraceOutputFolder(): - global traceOutputFolder, goldenTraceFilePath - traceOutputFolder = os.path.abspath(os.path.join(currentpath, "../trace_report_output")) - #print (traceOutputFolder) - goldenTraceFilePath = os.path.abspath(os.path.join(currentpath, "../baseline/llfi.stat.trace.prof.txt")) - if not os.path.exists(traceOutputFolder): - os.makedirs(traceOutputFolder) - else: - # Remove the contents in traceOutputFolder - for f in os.listdir(traceOutputFolder): - file_path = os.path.join(traceOutputFolder,f) - if os.path.isfile(file_path): - os.unlink(file_path) - if not os.path.isfile(goldenTraceFilePath): - print ("Cannot find golden Trace File 'llfi.stat.trace.prof.txt'") - + global traceOutputFolder, goldenTraceFilePath + traceOutputFolder = os.path.abspath( + os.path.join(currentpath, "../trace_report_output") + ) + # print (traceOutputFolder) + goldenTraceFilePath = os.path.abspath( + os.path.join(currentpath, "../baseline/llfi.stat.trace.prof.txt") + ) + if not os.path.exists(traceOutputFolder): + os.makedirs(traceOutputFolder) + else: + # Remove the contents in traceOutputFolder + for f in os.listdir(traceOutputFolder): + file_path = os.path.join(traceOutputFolder, f) + if os.path.isfile(file_path): + os.unlink(file_path) + if not os.path.isfile(goldenTraceFilePath): + print("Cannot find golden Trace File 'llfi.stat.trace.prof.txt'") def executeTraceDiff(): - traceFileCount = 0 - log_path =os.path.abspath(os.path.join(traceOutputFolder, "stderr_log.txt")) - log_file =open(log_path ,'w') - #Parse the goldenTraceFile path - tempgoldenTraceFilePath = goldenTraceFilePath - while "(" in tempgoldenTraceFilePath and not "\\(" in tempgoldenTraceFilePath: - tempgoldenTraceFilePath = tempgoldenTraceFilePath[:tempgoldenTraceFilePath.find("(")]+'\\('+ tempgoldenTraceFilePath[tempgoldenTraceFilePath.find("(")+1:] - while ")" in tempgoldenTraceFilePath and not "\\)" in tempgoldenTraceFilePath: - tempgoldenTraceFilePath = tempgoldenTraceFilePath[:tempgoldenTraceFilePath.find(")")]+'\\)'+ tempgoldenTraceFilePath[tempgoldenTraceFilePath.find(")")+1:] - tempScriptdir = scriptdir - #Parse the scriptdir path - while "(" in tempScriptdir and not "\\(" in tempScriptdir: - tempScriptdir = tempScriptdir[:tempScriptdir.find("(")]+'\\('+ tempScriptdir[tempScriptdir.find("(")+1:] - while ")" in tempScriptdir and not "\\)" in tempScriptdir: - tempScriptdir = tempScriptdir[:tempScriptdir.find(")")]+'\\)'+ tempScriptdir[tempScriptdir.find(")")+1:] - temptraceOutputFolder = traceOutputFolder - #Parse the traceOutputFolder path - while "(" in temptraceOutputFolder and not "\\(" in temptraceOutputFolder: - temptraceOutputFolder = temptraceOutputFolder[:temptraceOutputFolder.find("(")]+'\\('+ temptraceOutputFolder[temptraceOutputFolder.find("(")+1:] - while ")" in temptraceOutputFolder and not "\\)" in temptraceOutputFolder: - temptraceOutputFolder = temptraceOutputFolder[:temptraceOutputFolder.find(")")]+'\\)'+ temptraceOutputFolder[temptraceOutputFolder.find(")")+1:] - for file in os.listdir(currentpath): - if file.endswith(".txt") and file.startswith("llfi.stat.trace."): - cmd = tempScriptdir+"/tracediff "+tempgoldenTraceFilePath+" "+file+" > "+temptraceOutputFolder+"/TraceDiffReportFile"+file[file.find("llfi.stat.trace")+len("llfi.stat.trace"):] - p =subprocess.call(cmd,shell=True,stderr=log_file) - traceFileCount += 1 - #Check if trace files present, if not show error messages - if not traceFileCount > 0: - print ("Cannot find Trace input files.") - print ("Please make sure you are running this script in the llfi_stat_output folder") + traceFileCount = 0 + log_path = os.path.abspath(os.path.join(traceOutputFolder, "stderr_log.txt")) + with open(log_path, "w") as log_file: + for file in os.listdir(currentpath): + if file.endswith(".txt") and file.startswith("llfi.stat.trace."): + suffix = file[file.find("llfi.stat.trace") + len("llfi.stat.trace") :] + out_path = os.path.join( + traceOutputFolder, "TraceDiffReportFile" + suffix + ) + with open(out_path, "w") as out_file: + subprocess.call( + [ + os.path.join(scriptdir, "tracediff"), + goldenTraceFilePath, + file, + ], + stdout=out_file, + stderr=log_file, + ) + traceFileCount += 1 + if traceFileCount == 0: + print("Cannot find Trace input files.") + print( + "Please make sure you are running this script in the llfi_stat_output folder" + ) + def generateDotFile(): - log_path =os.path.abspath(os.path.join(traceOutputFolder, "stderr_log.txt")) - log_file =open(log_path ,'a') - goldenTraceDotFile = os.path.abspath(os.path.join(currentpath, "../../../llfi.stat.graph.dot")) - if not os.path.isfile(goldenTraceDotFile): - goldenTraceDotFile = os.path.abspath(os.path.join(currentpath, "../../llfi.stat.graph.dot")) - if not os.path.isfile(goldenTraceDotFile): - print ("Cannot find golden Trace Dot File 'llfi.stat.graph.dot'") - - #Parse the goldenTraceFile path - tempgoldenTraceFilePath = goldenTraceFilePath - while "(" in tempgoldenTraceFilePath and not "\\(" in tempgoldenTraceFilePath: - tempgoldenTraceFilePath = tempgoldenTraceFilePath[:tempgoldenTraceFilePath.find("(")]+'\\('+ tempgoldenTraceFilePath[tempgoldenTraceFilePath.find("(")+1:] - while ")" in tempgoldenTraceFilePath and not "\\)" in tempgoldenTraceFilePath: - tempgoldenTraceFilePath = tempgoldenTraceFilePath[:tempgoldenTraceFilePath.find(")")]+'\\)'+ tempgoldenTraceFilePath[tempgoldenTraceFilePath.find(")")+1:] - tempScriptdir = scriptdir - #Parse the scriptdir path - while "(" in tempScriptdir and not "\\(" in tempScriptdir: - tempScriptdir = tempScriptdir[:tempScriptdir.find("(")]+'\\('+ tempScriptdir[tempScriptdir.find("(")+1:] - while ")" in tempScriptdir and not "\\)" in tempScriptdir: - tempScriptdir = tempScriptdir[:tempScriptdir.find(")")]+'\\)'+ tempScriptdir[tempScriptdir.find(")")+1:] - temptraceOutputFolder = traceOutputFolder - #Parse the traceOutputFolder path - while "(" in temptraceOutputFolder and not "\\(" in temptraceOutputFolder: - temptraceOutputFolder = temptraceOutputFolder[:temptraceOutputFolder.find("(")]+'\\('+ temptraceOutputFolder[temptraceOutputFolder.find("(")+1:] - while ")" in temptraceOutputFolder and not "\\)" in temptraceOutputFolder: - temptraceOutputFolder = temptraceOutputFolder[:temptraceOutputFolder.find(")")]+'\\)'+ temptraceOutputFolder[temptraceOutputFolder.find(")")+1:] - tempgoldenTraceDotFile = goldenTraceDotFile - #Parse the traceOutputFolder path - while "(" in tempgoldenTraceDotFile and not "\\(" in tempgoldenTraceDotFile: - tempgoldenTraceDotFile = tempgoldenTraceDotFile[:tempgoldenTraceDotFile.find("(")]+'\\('+ tempgoldenTraceDotFile[tempgoldenTraceDotFile.find("(")+1:] - while ")" in tempgoldenTraceDotFile and not "\\)" in tempgoldenTraceDotFile: - tempgoldenTraceDotFile = tempgoldenTraceDotFile[:tempgoldenTraceDotFile.find(")")]+'\\)'+ tempgoldenTraceDotFile[tempgoldenTraceDotFile.find(")")+1:] - - - for file in os.listdir(traceOutputFolder): - if file.startswith("TraceDiffReportFile"): - # Parse the name - name = file[file.find("TraceDiffReportFile")+len("TraceDiffReportFile"):] - name = name.replace("txt", "dot") - cmd = tempScriptdir+"/traceontograph "+temptraceOutputFolder+"/"+file+" "+tempgoldenTraceDotFile+" > "+ temptraceOutputFolder+"/TraceGraph"+name - p =subprocess.call(cmd,shell=True,stderr=log_file) + log_path = os.path.abspath(os.path.join(traceOutputFolder, "stderr_log.txt")) + goldenTraceDotFile = os.path.abspath( + os.path.join(currentpath, "../../../llfi.stat.graph.dot") + ) + if not os.path.isfile(goldenTraceDotFile): + goldenTraceDotFile = os.path.abspath( + os.path.join(currentpath, "../../llfi.stat.graph.dot") + ) + if not os.path.isfile(goldenTraceDotFile): + print("Cannot find golden Trace Dot File 'llfi.stat.graph.dot'") + + with open(log_path, "a") as log_file: + for file in os.listdir(traceOutputFolder): + if file.startswith("TraceDiffReportFile"): + name = file[len("TraceDiffReportFile") :].replace("txt", "dot") + out_path = os.path.join(traceOutputFolder, "TraceGraph" + name) + with open(out_path, "w") as out_file: + subprocess.call( + [ + os.path.join(scriptdir, "traceontograph"), + os.path.join(traceOutputFolder, file), + goldenTraceDotFile, + ], + stdout=out_file, + stderr=log_file, + ) def main(args): - global currentpath, scriptdir, traceOutputFolder, goldenTraceFilePath - parseArgs(args) - findPath() - makeTraceOutputFolder() - executeTraceDiff() - generateDotFile() - -if __name__=="__main__": - main(sys.argv[1:]) + global currentpath, scriptdir, traceOutputFolder, goldenTraceFilePath + parseArgs(args) + findPath() + makeTraceOutputFolder() + executeTraceDiff() + generateDotFile() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/tools/tracetools.py b/tools/tracetools.py index 5e80251c..da93d8f0 100644 --- a/tools/tracetools.py +++ b/tools/tracetools.py @@ -1,9 +1,9 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 -#traceTools.py -#Author: Sam Coulter -#This file contains library functions and classes for the llfi tracing and -#tracing analyses scripts +# traceTools.py +# Author: Sam Coulter +# This file contains library functions and classes for the llfi tracing and +# tracing analyses scripts import sys import os @@ -11,513 +11,577 @@ import itertools import difflib - debugFlag = 10 + def debug(text, level=5): - global debugFlag - - if debugFlag > level: - print(text) + global debugFlag + + if debugFlag > level: + print(text) + goldenRemovedCount = [] faultyRemovedCount = [] + class diffBlock: - def __init__(self, lines): - - debug("\n\tCreating a diffBlock ") - debug(lines) - origHeader, newHeader = lines[0].replace('@',' ').replace('+',' ').replace('-',' ').split() - origsplit = origHeader.split(',') - newsplit = newHeader.split(',') - self.origStart = int(origsplit[0]) - self.newStart = int(newsplit[0]) - - self.preDiff = None - self.postDiff = None - - self.origLines = [] - self.newLines = [] - - if "+" not in lines[1] and "-" not in lines[1]: - if "S" in lines[1]: #See ugly hack in the diffReport init - lines[1] = lines[1][2:] - self.preDiff = lines.pop(1) - self.origStart += 1 - self.newStart += 1 - - if "+" not in lines[-1] and "-" not in lines[-1]: - self.postDiff = lines.pop(len(lines)-1) - - for line in lines[1:]: - if line==None: continue - if "-" in line: - self.origLines.append(line) - # print("Appending ", line, "to origLines") - if "+" in line: - self.newLines.append(line) - # print("Appending ", line, "to newLines") - - debug("\tDiffblock lines") - debug("\t" + "\n\t".join(lines)) - - self.origLength = len(self.origLines) - # assert(self.origLength > 0 and "self.origLines is empty") - self.newLength = len(self.newLines) - # assert(self.newLength > 0 and "self.newLines is empty") - - debug("\t" + "\n\t".join(self.origLines)) - debug("\t" + "\n\t".join(self.newLines)) - - #print some info for debugging - def printdebug(self): - print(self.origStart, self.newStart) - print('\n'.join(self.origLines)) - print('\n'.join(self.newLines)) - - #print the block analysis summary - def getSummary(self, adj=0): - origStart = self.origStart + adj - newStart = self.newStart + adj - DataDiffs = [] - CtrlDiffs = [] - instanceList = [] - izip = itertools.zip_longest(self.origLines, self.newLines) - - instance = diffInstance(0,0,0,0) - for i, (g, f) in enumerate(izip): - if not g == None: - debug("Golden Line " + str(g) ) - g = diffLine(g) - if not f == None: - debug("Fault Line " + str(f) ) - f = diffLine(f) - if g and f: - if g.ID == f.ID: - if instance.type != 1: - if (instance.summary != None): - instanceList.append(instance.summary()) - instance = diffInstance(1, findAdjustedPosition(origStart, goldenRemovedCount), \ - findAdjustedPosition(newStart, faultyRemovedCount), i) - instance.add("Data Diff: ID: " + str(g.ID) + " OPCode: " + str(g.OPCode) + \ - " Value: " + str(g.Value) + " \\ " + str(f.Value)) - instance.incOrigLength() - instance.incNewLength() - if (instance.summary != None): - instanceList.append(instance.summary()) - if (len(instanceList) < 2): - debug("Instance list has zero or one elements") - return None - return instanceList[1] + def __init__(self, lines): + + debug("\n\tCreating a diffBlock ") + debug(lines) + origHeader, newHeader = ( + lines[0].replace("@", " ").replace("+", " ").replace("-", " ").split() + ) + origsplit = origHeader.split(",") + newsplit = newHeader.split(",") + self.origStart = int(origsplit[0]) + self.newStart = int(newsplit[0]) + + self.preDiff = None + self.postDiff = None + + self.origLines = [] + self.newLines = [] + + if "+" not in lines[1] and "-" not in lines[1]: + if "S" in lines[1]: # See ugly hack in the diffReport init + lines[1] = lines[1][2:] + self.preDiff = lines.pop(1) + self.origStart += 1 + self.newStart += 1 + + if "+" not in lines[-1] and "-" not in lines[-1]: + self.postDiff = lines.pop(len(lines) - 1) + + for line in lines[1:]: + if line == None: + continue + if "-" in line: + self.origLines.append(line) + # print("Appending ", line, "to origLines") + if "+" in line: + self.newLines.append(line) + # print("Appending ", line, "to newLines") + + debug("\tDiffblock lines") + debug("\t" + "\n\t".join(lines)) + + self.origLength = len(self.origLines) + # assert(self.origLength > 0 and "self.origLines is empty") + self.newLength = len(self.newLines) + # assert(self.newLength > 0 and "self.newLines is empty") + + debug("\t" + "\n\t".join(self.origLines)) + debug("\t" + "\n\t".join(self.newLines)) + + # print some info for debugging + def printdebug(self): + print(self.origStart, self.newStart) + print("\n".join(self.origLines)) + print("\n".join(self.newLines)) + + # print the block analysis summary + def getSummary(self, adj=0): + origStart = self.origStart + adj + newStart = self.newStart + adj + DataDiffs = [] + CtrlDiffs = [] + instanceList = [] + izip = itertools.zip_longest(self.origLines, self.newLines) + + instance = diffInstance(0, 0, 0, 0) + for i, (g, f) in enumerate(izip): + if not g == None: + debug("Golden Line " + str(g)) + g = diffLine(g) + if not f == None: + debug("Fault Line " + str(f)) + f = diffLine(f) + if g and f: + if g.ID == f.ID: + if instance.type != 1: + if instance.summary != None: + instanceList.append(instance.summary()) + instance = diffInstance( + 1, + findAdjustedPosition(origStart, goldenRemovedCount), + findAdjustedPosition(newStart, faultyRemovedCount), + i, + ) + instance.add( + "Data Diff: ID: " + + str(g.ID) + + " OPCode: " + + str(g.OPCode) + + " Value: " + + str(g.Value) + + " \\ " + + str(f.Value) + ) + instance.incOrigLength() + instance.incNewLength() + if instance.summary != None: + instanceList.append(instance.summary()) + if len(instanceList) < 2: + debug("Instance list has zero or one elements") + return None + return instanceList[1] + class ctrlDiffBlock(diffBlock): - def getRange(self): - debug("Printing ctrlDiffBlock Range") # String has to be rturned this way for split to not fail later - debug(str(self.origStart) + " " + str(self.origLength) + " " + str(self.newStart) + " " + str(self.newLength)) - return self.origStart, self.origLength, \ - self.newStart, self.newLength - - def getSummary(self, adj=0): - origStart = self.origStart + adj - newStart = self.newStart + adj - DataDiffs = [] - CtrlDiffs = [] - instanceList = [] - - debug("ctrlDiffBlock getSummaryCall: " + str(adj)) - - izip = itertools.zip_longest(self.origLines, self.newLines) - - instance = diffInstance(0,0,0,0) - for i, (g, f) in enumerate(izip): - if g and f: - if instance.type != 2: - if (instance.summary != None): + def getRange(self): + debug( + "Printing ctrlDiffBlock Range" + ) # String has to be rturned this way for split to not fail later + debug( + str(self.origStart) + + " " + + str(self.origLength) + + " " + + str(self.newStart) + + " " + + str(self.newLength) + ) + return self.origStart, self.origLength, self.newStart, self.newLength + + def getSummary(self, adj=0): + origStart = self.origStart + adj + newStart = self.newStart + adj + DataDiffs = [] + CtrlDiffs = [] + instanceList = [] + + debug("ctrlDiffBlock getSummaryCall: " + str(adj)) + + izip = itertools.zip_longest(self.origLines, self.newLines) + + instance = diffInstance(0, 0, 0, 0) + for i, (g, f) in enumerate(izip): + if g and f: + if instance.type != 2: + if instance.summary != None: + instanceList.append( + instance.summary(self.preDiff, self.postDiff) + ) + instance = diffInstance(2, origStart, newStart, i) + instance.add("Ctrl Diff: ID: " + str(g[1:]) + " \\ " + str(f[1:])) + instance.incOrigLength() + instance.incNewLength() + if g and not f: + if instance.type != 2: + if instance.summary != None: + instanceList.append( + instance.summary(self.preDiff, self.postDiff) + ) + instance = diffInstance(2, origStart, newStart, i) + instance.add("Ctrl Diff: ID: " + str(g[1:]) + " \\ None") + instance.incOrigLength() + if f and not g: + if instance.type != 2: + if instance.summary != None: + instanceList.append( + instance.summary(self.preDiff, self.postDiff) + ) + instance = diffInstance(2, origStart, newStart, i) + instance.add("Ctrl Diff: ID: " + "None \\ " + str(f[1:])) + instance.incNewLength() + if instance.summary != None: instanceList.append(instance.summary(self.preDiff, self.postDiff)) - instance = diffInstance(2, origStart, newStart, i) - instance.add("Ctrl Diff: ID: " + str(g[1:]) + " \\ " + str(f[1:])) - instance.incOrigLength() - instance.incNewLength() - if g and not f: - if instance.type != 2: - if (instance.summary != None): - instanceList.append(instance.summary(self.preDiff, self.postDiff)) - instance = diffInstance(2, origStart, newStart, i) - instance.add("Ctrl Diff: ID: " + str(g[1:]) + " \\ None") - instance.incOrigLength() - if f and not g: - if instance.type != 2: - if (instance.summary != None): - instanceList.append(instance.summary(self.preDiff, self.postDiff)) - instance = diffInstance(2, origStart, newStart, i) - instance.add("Ctrl Diff: ID: " + "None \\ " + str(f[1:])) - instance.incNewLength() - if (instance.summary != None): - instanceList.append(instance.summary(self.preDiff, self.postDiff)) - return instanceList[1] - -def removeRangeFromLines(goldenLines, faultyLines, xxx_todo_changeme, adj = 0): - (gStart, gLength, fStart, fLength) = xxx_todo_changeme - global goldenRemovedCount - global faultyRemovedCount - - debug("\n\nRemovingRangeFromLines()") - - i = 0 - debug("GLen "+ str(gLength)) - debug("GStart " + str(gStart)) - while (i < gLength): - goldenLines[gStart+i-1] = "" - i += 1 - debug(str(i)) - goldenRemovedCount.append((gStart + adj, gLength)) - i = 0 - debug("FLen " +str(fLength)) - debug("FStart " + str(fStart)) - debug("FLines") - debug("\n".join(faultyLines)) - while (i < fLength): - faultyLines[fStart+i-1] = "" - i += 1 - debug(str(i)) - faultyRemovedCount.append((fStart + adj, fLength)) - debug("\nGolden after removal::::") - debug('\n'.join(goldenLines)) - debug("\nFaulty After removal::::") - debug('\n'.join(faultyLines)) - return goldenLines, faultyLines + return instanceList[1] -def findAdjustedPosition(position, remArray): - i = 0 - while i < len(remArray): - location, count = remArray[i] - if position >= location: - position = position + count - else: - return position - i += 1 - return position +def removeRangeFromLines(goldenLines, faultyLines, xxx_todo_changeme, adj=0): + gStart, gLength, fStart, fLength = xxx_todo_changeme + global goldenRemovedCount + global faultyRemovedCount -class diffInstance: - def __init__(self, insttype, origstart, newstart, adj): - debug("\t\tCreating a diffInstance") - self.origStart = origstart + adj - self.origLength = 0 - self.origEnd = 0 - self.newStart = newstart + adj - self.newLength = 0 - self.newEnd = 0 - self.type = insttype - self.lines = [] - - def add(self, line): - self.lines.append(line) - - def summary(self, preDiff=None, postDiff=None): - if len(self.lines) > 0: - self.origEnd = self.origStart + self.origLength - self.newEnd = self.newStart + self.newLength - header = "\nDiff@ inst # " + str(self.origStart) + "\\" + str(self.newStart) \ - + " -> inst # " + str(self.origEnd) + "\\" + str(self.newEnd) + "\n" - if preDiff != None: - header += "Pre Diff: ID: " + str(preDiff) + "\n" - if postDiff != None: - final = header + '\n'.join(self.lines) + "\nPost Diff: ID:" + postDiff - else: - final = header + '\n'.join(self.lines) - return final - else: - return None - - def incOrigLength(self): - self.origLength += 1 - - def incNewLength(self): - self.newLength += 1 + debug("\n\nRemovingRangeFromLines()") + i = 0 + debug("GLen " + str(gLength)) + debug("GStart " + str(gStart)) + while i < gLength: + goldenLines[gStart + i - 1] = "" + i += 1 + debug(str(i)) + goldenRemovedCount.append((gStart + adj, gLength)) + i = 0 + debug("FLen " + str(fLength)) + debug("FStart " + str(fStart)) + debug("FLines") + debug("\n".join(faultyLines)) + while i < fLength: + faultyLines[fStart + i - 1] = "" + i += 1 + debug(str(i)) + faultyRemovedCount.append((fStart + adj, fLength)) + debug("\nGolden after removal::::") + debug("\n".join(goldenLines)) + debug("\nFaulty After removal::::") + debug("\n".join(faultyLines)) + return goldenLines, faultyLines -class diffReport: - def __init__(self, goldenLines, faultyLines, startPoint, injectedID): - self.injectedID = injectedID - debug("Starting a diffReport, startpoint = " + str(startPoint)) - self.startPoint = startPoint - self.blocks = [] - - #perform ctrl diff analysis - goldenIDs = goldenLines[:] - faultyIDs = faultyLines[:] - goldenIDs = trimLinesToCtrlIDs(goldenIDs) - faultyIDs = trimLinesToCtrlIDs(faultyIDs) - -#This ugly hack forces the difflib routine to prioritize certain ctrl flow matches. -# TODO: the fundamental problem here is unix diff is not greedy -# so we might need to come up with a comprehensive fix -# The hack might not always work. Jiesheng +def findAdjustedPosition(position, remArray): i = 0 - while i < len(goldenIDs): - if i < len(faultyIDs) and goldenIDs[i] == faultyIDs[i]: - goldenIDs[i] = "S" + goldenIDs[i] - faultyIDs[i] = "S" + faultyIDs[i] - else: - break - i += 1 - - ctrldiff = list(difflib.unified_diff(goldenIDs[:], faultyIDs[:], n=1, lineterm='')) - - if ctrldiff: - ctrldiff.pop(0) - ctrldiff.pop(0) - - debug("\n".join(ctrldiff)) - debug("Length = " + str(len(ctrldiff))) - - i = 0 - length = 1 - start = None - - while (i < len(ctrldiff)): - if "@@ " in ctrldiff[i]: - if start != None: - debug("Calling ctrlDiffBlock constructor " + str(start) + " " + str(length)) - debug("\n".join(ctrldiff[start:start+length])) - cblock = ctrlDiffBlock(ctrldiff[start:start+length]) - self.blocks.append(cblock) - length = 1 - start = i + while i < len(remArray): + location, count = remArray[i] + if position >= location: + position = position + count else: - length += 1 + return position i += 1 - #Dont forget the last block in the diff! - if start != None: - debug("Calling ctrlDiffBlock constructor " + str(start) + " " + str(length)) - debug("\n".join(ctrldiff[start:start+length])) - cblock = ctrlDiffBlock(ctrldiff[start:start+length]) - self.blocks.append(cblock) - - debug("Golden Lines:\n" + "\n".join(goldenLines)) - debug("Faulty Lines:\n" + "\n".join(faultyLines)) - - debug("Removing ctrldiff ranges from lines") - for block in self.blocks: - goldenLines, faultyLines = removeRangeFromLines(goldenLines, faultyLines, \ - block.getRange(), self.startPoint) - debug("Golden Lines:\n" + "\n".join(goldenLines)) - debug("Faulty Lines:\n" + "\n".join(faultyLines)) - - goldenLines = [_f for _f in goldenLines if _f] - faultyLines = [_f for _f in faultyLines if _f] - - - datadiff = list(difflib.unified_diff(goldenLines, faultyLines, n=0, lineterm='')) - - if datadiff: - datadiff.pop(0) - datadiff.pop(0) - - #perform data diff analysis - i = 0 - length = 1 - start = None - - while (i < len(datadiff)): - if "@@ " in datadiff[i]: - if start != None: - self.blocks.append(diffBlock(datadiff[start:start+length])) - length = 1 - start = i + return position + + +class diffInstance: + def __init__(self, insttype, origstart, newstart, adj): + debug("\t\tCreating a diffInstance") + self.origStart = origstart + adj + self.origLength = 0 + self.origEnd = 0 + self.newStart = newstart + adj + self.newLength = 0 + self.newEnd = 0 + self.type = insttype + self.lines = [] + + def add(self, line): + self.lines.append(line) + + def summary(self, preDiff=None, postDiff=None): + if len(self.lines) > 0: + self.origEnd = self.origStart + self.origLength + self.newEnd = self.newStart + self.newLength + header = ( + "\nDiff@ inst # " + + str(self.origStart) + + "\\" + + str(self.newStart) + + " -> inst # " + + str(self.origEnd) + + "\\" + + str(self.newEnd) + + "\n" + ) + if preDiff != None: + header += "Pre Diff: ID: " + str(preDiff) + "\n" + if postDiff != None: + final = header + "\n".join(self.lines) + "\nPost Diff: ID:" + postDiff + else: + final = header + "\n".join(self.lines) + return final else: - length += 1 - i += 1 - #Dont forget the last block in the diff! - if start != None: - self.blocks.append(diffBlock(datadiff[start:start+length])) + return None + def incOrigLength(self): + self.origLength += 1 - def printSummary(self): - def sortFunc(x): - l = x.getSummary(self.startPoint) - if (l==None): return None - return l.split("\n")[1].replace('\\',' ').split()[3] - - #Sort the list of blocks by their starting point (wrt the golden trace) - self.blocks.sort(key = sortFunc) + def incNewLength(self): + self.newLength += 1 - for block in self.blocks: - if block.preDiff == None: - block.preDiff = self.injectedID - print(block.getSummary(self.startPoint)) -def trimLinesToCtrlIDs(lines): - i = 0 - while (i < len(lines)): - words = lines[i].split() - lines[i] = words[1] - i += 1 - return lines +class diffReport: + def __init__(self, goldenLines, faultyLines, startPoint, injectedID): + self.injectedID = injectedID + debug("Starting a diffReport, startpoint = " + str(startPoint)) + self.startPoint = startPoint + self.blocks = [] + + # perform ctrl diff analysis + goldenIDs = goldenLines[:] + faultyIDs = faultyLines[:] + goldenIDs = trimLinesToCtrlIDs(goldenIDs) + faultyIDs = trimLinesToCtrlIDs(faultyIDs) + + # This ugly hack forces the difflib routine to prioritize certain ctrl flow matches. + # TODO: the fundamental problem here is unix diff is not greedy + # so we might need to come up with a comprehensive fix + # The hack might not always work. Jiesheng + + i = 0 + while i < len(goldenIDs): + if i < len(faultyIDs) and goldenIDs[i] == faultyIDs[i]: + goldenIDs[i] = "S" + goldenIDs[i] + faultyIDs[i] = "S" + faultyIDs[i] + else: + break + i += 1 + + ctrldiff = list( + difflib.unified_diff(goldenIDs[:], faultyIDs[:], n=1, lineterm="") + ) + + if ctrldiff: + ctrldiff.pop(0) + ctrldiff.pop(0) + + debug("\n".join(ctrldiff)) + debug("Length = " + str(len(ctrldiff))) + + i = 0 + length = 1 + start = None + + while i < len(ctrldiff): + if "@@ " in ctrldiff[i]: + if start != None: + debug( + "Calling ctrlDiffBlock constructor " + + str(start) + + " " + + str(length) + ) + debug("\n".join(ctrldiff[start : start + length])) + cblock = ctrlDiffBlock(ctrldiff[start : start + length]) + self.blocks.append(cblock) + length = 1 + start = i + else: + length += 1 + i += 1 + # Dont forget the last block in the diff! + if start != None: + debug( + "Calling ctrlDiffBlock constructor " + + str(start) + + " " + + str(length) + ) + debug("\n".join(ctrldiff[start : start + length])) + cblock = ctrlDiffBlock(ctrldiff[start : start + length]) + self.blocks.append(cblock) + + debug("Golden Lines:\n" + "\n".join(goldenLines)) + debug("Faulty Lines:\n" + "\n".join(faultyLines)) + + debug("Removing ctrldiff ranges from lines") + for block in self.blocks: + goldenLines, faultyLines = removeRangeFromLines( + goldenLines, faultyLines, block.getRange(), self.startPoint + ) + debug("Golden Lines:\n" + "\n".join(goldenLines)) + debug("Faulty Lines:\n" + "\n".join(faultyLines)) + + goldenLines = [_f for _f in goldenLines if _f] + faultyLines = [_f for _f in faultyLines if _f] + + datadiff = list( + difflib.unified_diff(goldenLines, faultyLines, n=0, lineterm="") + ) + + if datadiff: + datadiff.pop(0) + datadiff.pop(0) + + # perform data diff analysis + i = 0 + length = 1 + start = None + + while i < len(datadiff): + if "@@ " in datadiff[i]: + if start != None: + self.blocks.append(diffBlock(datadiff[start : start + length])) + length = 1 + start = i + else: + length += 1 + i += 1 + # Dont forget the last block in the diff! + if start != None: + self.blocks.append(diffBlock(datadiff[start : start + length])) + + def printSummary(self): + def sortFunc(x): + l = x.getSummary(self.startPoint) + if l == None: + return None + return l.split("\n")[1].replace("\\", " ").split()[3] + + # Sort the list of blocks by their starting point (wrt the golden trace) + self.blocks.sort(key=sortFunc) + + for block in self.blocks: + if block.preDiff == None: + block.preDiff = self.injectedID + print(block.getSummary(self.startPoint)) + +def trimLinesToCtrlIDs(lines): + i = 0 + while i < len(lines): + words = lines[i].split() + lines[i] = words[1] + i += 1 + return lines class diffLine: - "Class to track line level differences in trace file" - - def __init__(self, rawLine): - if rawLine=="": return - self.raw = rawLine - elements = str(rawLine).split() - debug("RAWLINE: " + str(rawLine), 3) - assert( len(elements) >=5 and "Line doesn't have sufficient fields" ); - assert( elements[0] in ["ID:","-ID:","+ID:"] and "Can't find element ID" ); - assert( elements[2] == "OPCode:" and "Can't find field OpCode"); - assert( elements[4] == "Value:" and "Can't fine field Value" ); - self.ID = int(elements[1]) - self.OPCode = str(elements[3]) - self.Value = 0 - if (len(elements) > 5): - self.Value = int(elements[5],16) - - def _print(self): - print("ID:",self.ID, "OPCode", self.OPCode, "Value:", self.Value) - - def __str__(self): - return self.raw + "Class to track line level differences in trace file" + + def __init__(self, rawLine): + if rawLine == "": + return + self.raw = rawLine + elements = str(rawLine).split() + debug("RAWLINE: " + str(rawLine), 3) + assert len(elements) >= 5 and "Line doesn't have sufficient fields" + assert elements[0] in ["ID:", "-ID:", "+ID:"] and "Can't find element ID" + assert elements[2] == "OPCode:" and "Can't find field OpCode" + assert elements[4] == "Value:" and "Can't fine field Value" + self.ID = int(elements[1]) + self.OPCode = str(elements[3]) + self.Value = 0 + if len(elements) > 5: + self.Value = int(elements[5], 16) + + def _print(self): + print("ID:", self.ID, "OPCode", self.OPCode, "Value:", self.Value) + + def __str__(self): + return self.raw -class faultReport: - def __init__(self, lines): - self.instNumber = -1 - self.faultCount = -1 - self.faultID = -1 - self.faultOPCode = '' - self.goldValue = -1 - self.faultValues = [] - self.diffs = [] - - if lines[0] == "#FaultReport\n": - header = lines[1].split() - self.faultCount = int(header[0]) - self.instNumber = header[2] - - fault = lines[2].split() - self.faultID = int(fault[1]) - self.faultOPCode = fault[3] - self.goldValue = fault[5] - for i in range(self.faultCount): - self.faultValues.append(fault[7 + i]) - - i = 3 - while (i < len(lines)): - if "Diff" not in lines[i]: - break - else: - string = str(lines[i]) - if "@" in lines[i]: - string = "\n" + string - self.diffs.append(string) - i += 1 - else: - print("ERROR: Not a properly formed faultReport") - - def union(self, other): - if self.faultID == other.faultID: - self.faultCount += other.faultCount - self.diffs.extend(other.diffs) - self.faultValues.extend(other.faultValues) - - def report(self): - lines = [] - lines.append("#FaultReport\n") - header = str(self.faultCount) + " @ " + str(self.instNumber) + "\n" - lines.append(header) - faultline = "ID: " + str(self.faultID) + " OPCode: " + str(self.faultOPCode) - faultline += " Value: " + str(self.goldValue) + " / " - for val in self.faultValues: - faultline += " " + str(val) - faultline += '\n' - lines.append(faultline) - lines.extend(self.diffs) - return ''.join(lines) - - def getAffectedSet(self): - affectedInsts = set() - for diff in self.diffs: - if "@" in diff: - continue - else: - split = diff.split() - if "Data" in diff: - affectedInsts.add(int(split[3])) -# if "Ctrl" in diff: #Commenting out to remove ctrl diff -# if split[5] != "None": #affected instructions from being -# affectedInsts.add(int(split[5])) #coloured on the graph - if (int(self.faultID) in affectedInsts): - affectedInsts.remove(int(self.faultID)) - return affectedInsts - - def getAffectedEdgesSet(self): - affectedEdges = set() +class faultReport: + def __init__(self, lines): + self.instNumber = -1 + self.faultCount = -1 + self.faultID = -1 + self.faultOPCode = "" + self.goldValue = -1 + self.faultValues = [] + self.diffs = [] + + if lines[0] == "#FaultReport\n": + header = lines[1].split() + self.faultCount = int(header[0]) + self.instNumber = header[2] + + fault = lines[2].split() + self.faultID = int(fault[1]) + self.faultOPCode = fault[3] + self.goldValue = fault[5] + for i in range(self.faultCount): + self.faultValues.append(fault[7 + i]) + + i = 3 + while i < len(lines): + if "Diff" not in lines[i]: + break + else: + string = str(lines[i]) + if "@" in lines[i]: + string = "\n" + string + self.diffs.append(string) + i += 1 - i = 0 - while i+1 < len(self.diffs): - if "Diff@" in self.diffs[i] and "Pre Diff" in self.diffs[i+1]: - csplit = self.diffs[i+2].split() - edgeStart = int(self.diffs[i+1].split()[3]) - edgeEnd = None - if csplit[5] != "None": - edgeEnd = int(csplit[5]) - if (i+3 < len(self.diffs)): - affectedEdges.add((edgeEnd, int(self.diffs[i+3].split()[5]))) else: - d = i + 2 #Adjusting so we dont check the find the pre diff of the diff@ instance we - #are currently on. - while d < len(self.diffs): - if "Post Diff" in self.diffs[d]: - edgeEnd = self.diffs[d].split()[3] - d = len(self.diffs) - elif "Pre Diff" in self.diffs[d]: - d = len(self.diffs) #If we found a new ctrl diff block before finding a post diff, - #exit the loop early - d += 1 - affectedEdges.add((edgeStart, edgeEnd)) - i += 1 - - return affectedEdges + print("ERROR: Not a properly formed faultReport") + + def union(self, other): + if self.faultID == other.faultID: + self.faultCount += other.faultCount + self.diffs.extend(other.diffs) + self.faultValues.extend(other.faultValues) + + def report(self): + lines = [] + lines.append("#FaultReport\n") + header = str(self.faultCount) + " @ " + str(self.instNumber) + "\n" + lines.append(header) + faultline = "ID: " + str(self.faultID) + " OPCode: " + str(self.faultOPCode) + faultline += " Value: " + str(self.goldValue) + " / " + for val in self.faultValues: + faultline += " " + str(val) + faultline += "\n" + lines.append(faultline) + lines.extend(self.diffs) + return "".join(lines) + + def getAffectedSet(self): + affectedInsts = set() + for diff in self.diffs: + if "@" in diff: + continue + else: + split = diff.split() + if "Data" in diff: + affectedInsts.add(int(split[3])) + # if "Ctrl" in diff: #Commenting out to remove ctrl diff + # if split[5] != "None": #affected instructions from being + # affectedInsts.add(int(split[5])) #coloured on the graph + if int(self.faultID) in affectedInsts: + affectedInsts.remove(int(self.faultID)) + return affectedInsts + + def getAffectedEdgesSet(self): + affectedEdges = set() + + i = 0 + while i + 1 < len(self.diffs): + if "Diff@" in self.diffs[i] and "Pre Diff" in self.diffs[i + 1]: + csplit = self.diffs[i + 2].split() + edgeStart = int(self.diffs[i + 1].split()[3]) + edgeEnd = None + if csplit[5] != "None": + edgeEnd = int(csplit[5]) + if i + 3 < len(self.diffs): + affectedEdges.add((edgeEnd, int(self.diffs[i + 3].split()[5]))) + else: + d = ( + i + 2 + ) # Adjusting so we dont check the find the pre diff of the diff@ instance we + # are currently on. + while d < len(self.diffs): + if "Post Diff" in self.diffs[d]: + edgeEnd = self.diffs[d].split()[3] + d = len(self.diffs) + elif "Pre Diff" in self.diffs[d]: + d = len( + self.diffs + ) # If we found a new ctrl diff block before finding a post diff, + # exit the loop early + d += 1 + affectedEdges.add((edgeStart, edgeEnd)) + i += 1 + + return affectedEdges + def parseFaultReportsfromFile(target): - reports = [] - reportFile = open(target, 'r') - fileLines = reportFile.readlines() - - #Remove blank lines from list - i = 0 - length = len(fileLines) - while i < length: - if not fileLines[i].strip(): - fileLines.pop(i) - length -= 1 - i += 1 - - #Parse the faultReports - i = 0 - fileLineCount = len(fileLines) - - while (i < fileLineCount): - if "#FaultReport" in fileLines[i]: - temp = [] - temp.append(fileLines[i]) - i += 1 - while ("#FaultReport" not in fileLines[i]): - temp.append(fileLines[i]) + reports = [] + with open(target, "r") as reportFile: + fileLines = reportFile.readlines() + + # Remove blank lines from list + i = 0 + length = len(fileLines) + while i < length: + if not fileLines[i].strip(): + fileLines.pop(i) + length -= 1 i += 1 - if i >= fileLineCount: - break - reports.append(faultReport(temp)) - else: - i += 1 - if i >= fileLineCount: - break - return reports + # Parse the faultReports + i = 0 + fileLineCount = len(fileLines) + + while i < fileLineCount: + if "#FaultReport" in fileLines[i]: + temp = [] + temp.append(fileLines[i]) + i += 1 + while "#FaultReport" not in fileLines[i]: + temp.append(fileLines[i]) + i += 1 + if i >= fileLineCount: + break + reports.append(faultReport(temp)) + else: + i += 1 + if i >= fileLineCount: + break + return reports diff --git a/tools/traceunion.py b/tools/traceunion.py index 9dc1b7ae..3b81483b 100755 --- a/tools/traceunion.py +++ b/tools/traceunion.py @@ -1,55 +1,64 @@ #! /usr/bin/env python3 -#traceUnion.py -#Author: Sam Coulter -#This script will take any number (1+) of fault tracing reports as input, and output -#a combined (union'd) faultreport to standard input, use pipe redirection to -#save to file -#Example Usage: +# traceUnion.py +# Author: Sam Coulter +# This script will take any number (1+) of fault tracing reports as input, and output +# a combined (union'd) faultreport to standard input, use pipe redirection to +# save to file +# Example Usage: # ./traceUnion.py file1 file2 file3 ... fileN > finalFile +import sys + from tracetools import * prog = os.path.basename(sys.argv[0]) -def traceUnion(argv, output=0): - #save stdout so we can redirect it without mangling other python scripts - oldSTDOut = sys.stdout - if output != 0: - sys.stdout = open(output, "wb") +def traceUnion(argv, output=0): + # save stdout so we can redirect it without mangling other python scripts + oldSTDOut = sys.stdout + if output != 0: + sys.stdout = open(output, "wb") # intentional stdout redirect - reps = [] - for f in argv: - reps.extend(parseFaultReportsfromFile(f)) + reps = [] + for f in argv: + reps.extend(parseFaultReportsfromFile(f)) - i = 0 - x = 1 - length = len(reps) - while i < length: - while x < length: - if reps[i].faultID == reps[x].faultID: - reps[i].union(reps[x]) - reps.pop(x) - length = len(reps) - x += 1 - i += 1 + i = 0 + x = 1 + length = len(reps) + while i < length: + while x < length: + if reps[i].faultID == reps[x].faultID: + reps[i].union(reps[x]) + reps.pop(x) + length = len(reps) + x += 1 + i += 1 - for rep in reps: - print(rep.report()) + for rep in reps: + print(rep.report()) + # restore stdout + sys.stdout = oldSTDOut - #restore stdout - sys.stdout = oldSTDOut if __name__ == "__main__": - if len(sys.argv) >= 2 and (sys.argv[1] == '-h' or sys.argv[1] == '--help'): - print(("%(prog)s takes more than one input program trace difference summary file and combines them to one report\n\n" - "running option: %(prog)s file1 file2 ..." %{"prog": prog}), file=sys.stderr) - elif len(sys.argv) >= 3: - traceUnion(sys.argv[1:]) - else: - print("Error: running option: %(prog)s file1 file2 ..." %{"prog": prog}, file=sys.stderr) - exit(1) - + if len(sys.argv) >= 2 and (sys.argv[1] == "-h" or sys.argv[1] == "--help"): + print( + ( + "%(prog)s takes more than one input program trace difference summary file and combines them to one report\n\n" + "running option: %(prog)s file1 file2 ..." % {"prog": prog} + ), + file=sys.stderr, + ) + elif len(sys.argv) >= 3: + traceUnion(sys.argv[1:]) + else: + print( + "Error: running option: %(prog)s file1 file2 ..." % {"prog": prog}, + file=sys.stderr, + ) + sys.exit(1) diff --git a/tutorials/ISSRE19/1-sqrt/measure.py b/tutorials/ISSRE19/1-sqrt/measure.py index 9d20a8ef..91422845 100644 --- a/tutorials/ISSRE19/1-sqrt/measure.py +++ b/tutorials/ISSRE19/1-sqrt/measure.py @@ -11,63 +11,59 @@ errdir = curdir + "/llfi/error_output" # read golden output from ./baseline/golden_std_output -print ("Reading golden output...") +print("Reading golden output...") file_gld_out = baseline + "/golden_std_output" -print ("Complete.\n") +print("Complete.\n") # read filenames from ./std_output -print ("Reading filenames...") +print("Reading filenames...") path, dirs, files = os.walk(std_output).__next__() run_count = len(files) -print ("Complete. " + str(run_count) + " fault injection runs were performed\n") +print("Complete. " + str(run_count) + " fault injection runs were performed\n") # check for SDCs sdc_count = 0 benign_count = 0 crash_count = 0 hang_count = 0 -print ("Checking files...") +print("Checking files...") for f in range(0, run_count): - print ("Checking fault injection run " + str(f) + "...", end="\r") + print("Checking fault injection run " + str(f) + "...", end="\r") file_out = std_output + "/std_outputfile-run-0-" + str(f) try: file_err = open(errdir + "/errorfile-run-0-" + str(f)) error_msg = file_err.read() file_err.close() - except FileNotFoundError: # no error output + except FileNotFoundError: # no error output error_msg = "" - if ("hang" in error_msg): + if "hang" in error_msg: hang_count += 1 - elif ("crash" in error_msg): + elif "crash" in error_msg: crash_count += 1 elif filecmp.cmp(file_out, file_gld_out): benign_count += 1 else: sdc_count += 1 sys.stdout.write("\033[K") -print ("Complete.", end="\r") +print("Complete.", end="\r") # print results -print ("\n") -print ("SDC count = " + str(sdc_count)) -print ("Crash count = " + str(crash_count)) -print ("Benign count = " + str(benign_count)) -print ("Hang count = " + str(hang_count)) -print ("Total Fi runs = " + str(run_count)) +print("\n") +print("SDC count = " + str(sdc_count)) +print("Crash count = " + str(crash_count)) +print("Benign count = " + str(benign_count)) +print("Hang count = " + str(hang_count)) +print("Total Fi runs = " + str(run_count)) # print results to file -par_dir = os.path.split( os.path.abspath(os.path.join(os.path.dirname( __file__ ), os.pardir)))[1] -out = open(curdir + "/results.txt", 'w') +par_dir = os.path.split( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +)[1] +out = open(curdir + "/results.txt", "w") sys.stdout = out -print ("SDC count = " + str(sdc_count)) -print ("Crash count = " + str(crash_count)) -print ("Benign count = " + str(benign_count)) -print ("Hang count = " + str(hang_count)) -print ("Total Fi runs = " + str(run_count)) +print("SDC count = " + str(sdc_count)) +print("Crash count = " + str(crash_count)) +print("Benign count = " + str(benign_count)) +print("Hang count = " + str(hang_count)) +print("Total Fi runs = " + str(run_count)) out.close() - - - - - - diff --git a/tutorials/ISSRE19/2-matmult/measure.py b/tutorials/ISSRE19/2-matmult/measure.py index 9d20a8ef..91422845 100644 --- a/tutorials/ISSRE19/2-matmult/measure.py +++ b/tutorials/ISSRE19/2-matmult/measure.py @@ -11,63 +11,59 @@ errdir = curdir + "/llfi/error_output" # read golden output from ./baseline/golden_std_output -print ("Reading golden output...") +print("Reading golden output...") file_gld_out = baseline + "/golden_std_output" -print ("Complete.\n") +print("Complete.\n") # read filenames from ./std_output -print ("Reading filenames...") +print("Reading filenames...") path, dirs, files = os.walk(std_output).__next__() run_count = len(files) -print ("Complete. " + str(run_count) + " fault injection runs were performed\n") +print("Complete. " + str(run_count) + " fault injection runs were performed\n") # check for SDCs sdc_count = 0 benign_count = 0 crash_count = 0 hang_count = 0 -print ("Checking files...") +print("Checking files...") for f in range(0, run_count): - print ("Checking fault injection run " + str(f) + "...", end="\r") + print("Checking fault injection run " + str(f) + "...", end="\r") file_out = std_output + "/std_outputfile-run-0-" + str(f) try: file_err = open(errdir + "/errorfile-run-0-" + str(f)) error_msg = file_err.read() file_err.close() - except FileNotFoundError: # no error output + except FileNotFoundError: # no error output error_msg = "" - if ("hang" in error_msg): + if "hang" in error_msg: hang_count += 1 - elif ("crash" in error_msg): + elif "crash" in error_msg: crash_count += 1 elif filecmp.cmp(file_out, file_gld_out): benign_count += 1 else: sdc_count += 1 sys.stdout.write("\033[K") -print ("Complete.", end="\r") +print("Complete.", end="\r") # print results -print ("\n") -print ("SDC count = " + str(sdc_count)) -print ("Crash count = " + str(crash_count)) -print ("Benign count = " + str(benign_count)) -print ("Hang count = " + str(hang_count)) -print ("Total Fi runs = " + str(run_count)) +print("\n") +print("SDC count = " + str(sdc_count)) +print("Crash count = " + str(crash_count)) +print("Benign count = " + str(benign_count)) +print("Hang count = " + str(hang_count)) +print("Total Fi runs = " + str(run_count)) # print results to file -par_dir = os.path.split( os.path.abspath(os.path.join(os.path.dirname( __file__ ), os.pardir)))[1] -out = open(curdir + "/results.txt", 'w') +par_dir = os.path.split( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +)[1] +out = open(curdir + "/results.txt", "w") sys.stdout = out -print ("SDC count = " + str(sdc_count)) -print ("Crash count = " + str(crash_count)) -print ("Benign count = " + str(benign_count)) -print ("Hang count = " + str(hang_count)) -print ("Total Fi runs = " + str(run_count)) +print("SDC count = " + str(sdc_count)) +print("Crash count = " + str(crash_count)) +print("Benign count = " + str(benign_count)) +print("Hang count = " + str(hang_count)) +print("Total Fi runs = " + str(run_count)) out.close() - - - - - - diff --git a/tutorials/ISSRE19/3-matmult_trace/measure.py b/tutorials/ISSRE19/3-matmult_trace/measure.py index 9d20a8ef..91422845 100644 --- a/tutorials/ISSRE19/3-matmult_trace/measure.py +++ b/tutorials/ISSRE19/3-matmult_trace/measure.py @@ -11,63 +11,59 @@ errdir = curdir + "/llfi/error_output" # read golden output from ./baseline/golden_std_output -print ("Reading golden output...") +print("Reading golden output...") file_gld_out = baseline + "/golden_std_output" -print ("Complete.\n") +print("Complete.\n") # read filenames from ./std_output -print ("Reading filenames...") +print("Reading filenames...") path, dirs, files = os.walk(std_output).__next__() run_count = len(files) -print ("Complete. " + str(run_count) + " fault injection runs were performed\n") +print("Complete. " + str(run_count) + " fault injection runs were performed\n") # check for SDCs sdc_count = 0 benign_count = 0 crash_count = 0 hang_count = 0 -print ("Checking files...") +print("Checking files...") for f in range(0, run_count): - print ("Checking fault injection run " + str(f) + "...", end="\r") + print("Checking fault injection run " + str(f) + "...", end="\r") file_out = std_output + "/std_outputfile-run-0-" + str(f) try: file_err = open(errdir + "/errorfile-run-0-" + str(f)) error_msg = file_err.read() file_err.close() - except FileNotFoundError: # no error output + except FileNotFoundError: # no error output error_msg = "" - if ("hang" in error_msg): + if "hang" in error_msg: hang_count += 1 - elif ("crash" in error_msg): + elif "crash" in error_msg: crash_count += 1 elif filecmp.cmp(file_out, file_gld_out): benign_count += 1 else: sdc_count += 1 sys.stdout.write("\033[K") -print ("Complete.", end="\r") +print("Complete.", end="\r") # print results -print ("\n") -print ("SDC count = " + str(sdc_count)) -print ("Crash count = " + str(crash_count)) -print ("Benign count = " + str(benign_count)) -print ("Hang count = " + str(hang_count)) -print ("Total Fi runs = " + str(run_count)) +print("\n") +print("SDC count = " + str(sdc_count)) +print("Crash count = " + str(crash_count)) +print("Benign count = " + str(benign_count)) +print("Hang count = " + str(hang_count)) +print("Total Fi runs = " + str(run_count)) # print results to file -par_dir = os.path.split( os.path.abspath(os.path.join(os.path.dirname( __file__ ), os.pardir)))[1] -out = open(curdir + "/results.txt", 'w') +par_dir = os.path.split( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +)[1] +out = open(curdir + "/results.txt", "w") sys.stdout = out -print ("SDC count = " + str(sdc_count)) -print ("Crash count = " + str(crash_count)) -print ("Benign count = " + str(benign_count)) -print ("Hang count = " + str(hang_count)) -print ("Total Fi runs = " + str(run_count)) +print("SDC count = " + str(sdc_count)) +print("Crash count = " + str(crash_count)) +print("Benign count = " + str(benign_count)) +print("Hang count = " + str(hang_count)) +print("Total Fi runs = " + str(run_count)) out.close() - - - - - - diff --git a/web-app/.eslintrc.js b/web-app/.eslintrc.js deleted file mode 100644 index 5a7a0173..00000000 --- a/web-app/.eslintrc.js +++ /dev/null @@ -1,268 +0,0 @@ -module.exports = { - "env": { - "es6": true, - "node": true - }, - "extends": "eslint:recommended", - "parserOptions": { - "ecmaFeatures": { - "experimentalObjectRestSpread": true, - "jsx": true - }, - "sourceType": "module" - }, - "plugins": [ - "react" - ], - "rules": { - "accessor-pairs": "error", - "array-bracket-spacing": [ - "error", - "never" - ], - "array-callback-return": "error", - "arrow-body-style": "error", - "arrow-parens": [ - "error", - "as-needed" - ], - "arrow-spacing": "off", - "block-scoped-var": "off", - "block-spacing": [ - "error", - "never" - ], - "brace-style": "off", - "callback-return": "error", - "camelcase": "off", - "class-methods-use-this": "error", - "comma-dangle": "off", - "comma-spacing": "off", - "comma-style": [ - "error", - "last" - ], - "complexity": "error", - "computed-property-spacing": [ - "error", - "never" - ], - "consistent-return": "off", - "consistent-this": "off", - "curly": "off", - "default-case": "error", - "dot-location": "error", - "dot-notation": [ - "error", - { - "allowKeywords": true - } - ], - "eol-last": "off", - "eqeqeq": "off", - "func-call-spacing": "error", - "func-name-matching": "error", - "func-names": [ - "error", - "never" - ], - "func-style": "off", - "generator-star-spacing": "error", - "global-require": "error", - "guard-for-in": "error", - "handle-callback-err": "error", - "id-blacklist": "error", - "id-length": "off", - "id-match": "error", - "indent": "off", - "init-declarations": "off", - "jsx-quotes": [ - "error", - "prefer-double" - ], - "key-spacing": "off", - "keyword-spacing": "off", - "line-comment-position": "error", - "linebreak-style": [ - "error", - "unix" - ], - "lines-around-comment": "error", - "lines-around-directive": "off", - "max-depth": "error", - "max-len": "off", - "max-lines": "off", - "max-nested-callbacks": "error", - "max-params": "error", - "max-statements": "off", - "max-statements-per-line": "off", - "multiline-ternary": [ - "error", - "never" - ], - "new-cap": "error", - "new-parens": "error", - "newline-after-var": "off", - "newline-before-return": "off", - "newline-per-chained-call": "off", - "no-alert": "off", - "no-array-constructor": "error", - "no-bitwise": "error", - "no-caller": "error", - "no-catch-shadow": "off", - "no-confusing-arrow": "error", - "no-continue": "error", - "no-console":0, - "no-div-regex": "error", - "no-duplicate-imports": "error", - "no-else-return": "off", - "no-empty-function": "error", - "no-eq-null": "error", - "no-eval": "error", - "no-extend-native": "off", - "no-extra-bind": "error", - "no-extra-label": "error", - "no-extra-parens": "off", - "no-floating-decimal": "error", - "no-implicit-coercion": "error", - "no-implicit-globals": "error", - "no-implied-eval": "error", - "no-inline-comments": "error", - "no-inner-declarations": [ - "error", - "functions" - ], - "no-invalid-this": "error", - "no-iterator": "error", - "no-label-var": "error", - "no-labels": "error", - "no-lone-blocks": "error", - "no-lonely-if": "error", - "no-loop-func": "off", - "no-magic-numbers": "off", - "no-mixed-operators": "error", - "no-mixed-requires": "error", - "no-multi-spaces": "error", - "no-multi-str": "error", - "no-multiple-empty-lines": "off", - "no-native-reassign": "error", - "no-negated-condition": "error", - "no-negated-in-lhs": "error", - "no-nested-ternary": "error", - "no-new": "error", - "no-new-func": "error", - "no-new-object": "error", - "no-new-require": "error", - "no-new-wrappers": "error", - "no-octal-escape": "error", - "no-param-reassign": "error", - "no-path-concat": "error", - "no-plusplus": "off", - "no-process-env": "error", - "no-process-exit": "error", - "no-proto": "error", - "no-prototype-builtins": "error", - "no-redeclare": 0, - "no-restricted-globals": "error", - "no-restricted-imports": "error", - "no-restricted-modules": "error", - "no-restricted-properties": "error", - "no-restricted-syntax": "error", - "no-return-assign": "error", - "no-return-await": "error", - "no-script-url": "error", - "no-self-compare": "error", - "no-sequences": "error", - "no-shadow": "off", - "no-shadow-restricted-names": "error", - "no-spaced-func": "error", - "no-sync": "off", - "no-tabs": "off", - "no-template-curly-in-string": "error", - "no-ternary": "off", - "no-throw-literal": "error", - "no-trailing-spaces": "off", - "no-undef": 0, - "no-undef-init": "error", - "no-undefined": "error", - "no-unmodified-loop-condition": "error", - "no-unneeded-ternary": "off", - "no-unused-expressions": "off", - "no-unused-vars": 0, - "no-use-before-define": "off", - "no-useless-call": "error", - "no-useless-computed-key": "error", - "no-useless-concat": "off", - "no-useless-constructor": "error", - "no-useless-escape": "error", - "no-useless-rename": "error", - "no-useless-return": "error", - "no-var": "off", - "no-void": "error", - "no-warning-comments": "off", - "no-whitespace-before-property": "error", - "no-with": "error", - "object-curly-newline": "off", - "object-curly-spacing": "off", - "object-property-newline": [ - "error", - { - "allowMultiplePropertiesPerLine": true - } - ], - "object-shorthand": "off", - "one-var": "off", - "one-var-declaration-per-line": "off", - "operator-assignment": [ - "error", - "always" - ], - "operator-linebreak": "error", - "padded-blocks": "off", - "prefer-arrow-callback": "off", - "prefer-const": "error", - "prefer-numeric-literals": "error", - "prefer-reflect": "error", - "prefer-rest-params": "error", - "prefer-spread": "error", - "prefer-template": "off", - "quote-props": "off", - "quotes": "off", - "radix": [ - "error", - "as-needed" - ], - "require-jsdoc": "off", - "rest-spread-spacing": "error", - "semi": "error", - "semi-spacing": "off", - "sort-imports": "off", - "sort-keys": "off", - "sort-vars": "off", - "space-before-blocks": "off", - "space-before-function-paren": "off", - "space-in-parens": [ - "error", - "never" - ], - "space-infix-ops": "off", - "space-unary-ops": "off", - "spaced-comment": "off", - "strict": "error", - "symbol-description": "error", - "template-curly-spacing": "error", - "unicode-bom": [ - "error", - "never" - ], - "valid-jsdoc": "error", - "vars-on-top": "off", - "wrap-iife": "error", - "wrap-regex": "error", - "yield-star-spacing": "error", - "yoda": [ - "error", - "never" - ] - } -}; \ No newline at end of file diff --git a/web-app/README.MD b/web-app/README.MD deleted file mode 100644 index 021f7162..00000000 --- a/web-app/README.MD +++ /dev/null @@ -1,81 +0,0 @@ -## Dependencies: ## -Nodejs. -webpack - -## Steps to set up the development environment: ## -1. Download NodeJs (v10.19.0) -2. Install libraries: Go to the web-app directory and run "npm install". The required versions can be found in package.json file. -3. Install Webpack: In the same directory as step 3, run "sudo npm install -g webpack". -4. Configurate the LLFI root path for the server: -The program uses the environment variable `$llfibuild` as the path of the llfi build directory. -You can set the environment variable `$llfibuild` in your system to point it to the LLFI build directory in your local machine. - -**Start the server:** -Go to the /web-app/server folder and run "node server.js" - -**Start the front-end dev tool:** -Go to the web-app directory and run "webpack" or "webpack -w" - - - -## Overview of the Web Application ## - -Reference: [LLFI Web GUI](https://github.com/DependableSystemsLab/LLFI/wiki/Get-Started-with-using-LLFI-with-Web-GUI) - -The procedure for performing fault injections in LLTFI and view results using the Web GUI is described below. - -### Fault Injection Procedure using the GUI ### - -The first step is to select the target program for fault injection. The user can upload any standalone C/C++ source file by selecting; File->Open File. - -The figure below is a screenshot for opening a single file. - -![Alt text](images/fileUpload.png?raw=true) - -Once the file is uploaded: - -1) **Compile To IR**: After a C/C++ source file is uploaded, click on the "Compile To IR" button. When the program is successfully compiled, a success message with the name of IR file that is created will be displayed. The figure below is a screenshot of the successful completion of the "Compile To IR" step. - -![Alt text](images/CompileToIR.png?raw=true) - -2) **Instrument**: Once the intermediate representation (IR) file is created, click on the "Instrument" button. The instrument dialog box will open, where the user can select the configuration parameters (input.yaml). - -For Hardware Faults Injection, select "Hardware Injection", then select the Instruction Type, Register Location and trace options and then click the "Instrument" button in the dialog box. The figure below is an example of the Hardware Faults Instrument interface. - -![Alt text](images/Instrument.png?raw=true) - - - -3) **Profiling**: After instrumentation, click on the "Profiling" button. If the program requires any command line inputs, provide them in the "Command Line Input" text box before clicking "Profiling". - -The figure below is the successful completion of hardware fault profiling. - -![Alt text](images/Profiling.png?raw=true) - -4) **Runtime Options**: Click on the "Runtime Options" button to configure the fault injection. The user can also save the profile by checking the "Save Profile" checkbox before clicking "Submit". Click on the "Delete Run" button to delete a runtime option. - -For Hardware Fault Injection, the user should provide the 'Number of Runs' and 'Fault injection type' (mandatory), while other inputs are optional. The figure below is a screenshot of the Hardware Runtime Options interface. - -![Alt text](images/RuntimeOptions.png?raw=true) - - - -5) **Fault Injection**: Click on the "Inject Fault" button to perform fault injection. When fault injection is completed click on the "Fault Injection Status" and "Fault Summary" tabs to view the result of the fault injections. - -The figure below is a screenshot of the completion of a Hardware "Inject Fault", you can also see the "Fault Injection Status" table in the figure. - -![Alt text](images/InjectFault.png?raw=true) -![Alt text](images/InjectFault2.png?raw=true) - -6) **Trace Graph**: Select one or more traces from the "Fault Injection Status" tabs. Click on the "Trace Graph" button and a trace graph window will pop up. The trace graph contains the flow of the LLVM IR instructions, and the mapping information of LLVM IR instructions to the C source code (if applicable). - -![Alt text](images/TraceGraph.png?raw=true) - -**SDC Option**: By default, the SDC's are reported by making a diff between the std_outputfile and the golden_output. The user can customize the way SDC is generated by writing a script and providing the path of the script to the environment variable COMPARE. - - - diff --git a/web-app/images/CompileToIR.png b/web-app/images/CompileToIR.png deleted file mode 100644 index 6753124f..00000000 Binary files a/web-app/images/CompileToIR.png and /dev/null differ diff --git a/web-app/images/InjectFault.png b/web-app/images/InjectFault.png deleted file mode 100644 index c8bcd877..00000000 Binary files a/web-app/images/InjectFault.png and /dev/null differ diff --git a/web-app/images/InjectFault2.png b/web-app/images/InjectFault2.png deleted file mode 100644 index 1d2a9450..00000000 Binary files a/web-app/images/InjectFault2.png and /dev/null differ diff --git a/web-app/images/Instrument.png b/web-app/images/Instrument.png deleted file mode 100644 index 080c54d8..00000000 Binary files a/web-app/images/Instrument.png and /dev/null differ diff --git a/web-app/images/Profiling.png b/web-app/images/Profiling.png deleted file mode 100644 index 3c0f5dc5..00000000 Binary files a/web-app/images/Profiling.png and /dev/null differ diff --git a/web-app/images/RuntimeOptions.png b/web-app/images/RuntimeOptions.png deleted file mode 100644 index b2639676..00000000 Binary files a/web-app/images/RuntimeOptions.png and /dev/null differ diff --git a/web-app/images/TraceGraph.png b/web-app/images/TraceGraph.png deleted file mode 100644 index f0d9cdcc..00000000 Binary files a/web-app/images/TraceGraph.png and /dev/null differ diff --git a/web-app/images/fileUpload.png b/web-app/images/fileUpload.png deleted file mode 100644 index c929f2d4..00000000 Binary files a/web-app/images/fileUpload.png and /dev/null differ diff --git a/web-app/package.json b/web-app/package.json deleted file mode 100644 index cd77ffde..00000000 --- a/web-app/package.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "name": "llfi-web-app", - "version": "1.0.0", - "description": "npm install\r npm install -g webpack\r webpack\r webpack -w", - "main": "webpack.config.js", - "scripts": { - "dev": "./node_modules/.bin/webpack-dev-server --content-base src --inline --hot", - "test": "echo \"Error: no test specified\" && exit 1" - }, - "author": "Kenneth Song", - "license": "ISC", - "dependencies": { - "babel-loader": "^6.2.0", - "babel-plugin-add-module-exports": "^0.1.2", - "babel-plugin-react-html-attrs": "^2.0.0", - "babel-plugin-transform-class-properties": "^6.3.13", - "babel-plugin-transform-decorators-legacy": "^1.3.4", - "babel-preset-es2015": "^6.3.13", - "babel-preset-react": "^6.3.13", - "babel-preset-stage-0": "^6.3.13", - "chart.js": "^2.3.0", - "downloadbutton": "^1.0.0", - "eslint": "^3.10.2", - "express": "^4.14.0", - "formidable": "^1.0.17", - "react": "^0.14.6", - "react-dom": "^0.14.6", - "react-filtered-multiselect": "^0.4.2", - "reflux": "^0.4.1", - "webpack": "^1.12.9", - "webpack-dev-server": "^1.14.1" - }, - "devDependencies": { - "babel-core": "^6.17.0", - "babel-loader": "^6.2.5", - "body-parser": "^1.15.2", - "eslint": "^3.10.2", - "eslint-config-standard": "^6.2.1", - "eslint-plugin-promise": "^3.4.0", - "eslint-plugin-react": "^6.7.1", - "eslint-plugin-standard": "^2.0.1", - "react-bootstrap": "^0.30.5", - "react-select": "^1.0.0-rc.2", - "webpack": "^1.13.2" - }, - "repository": { - "type": "git", - "url": "\u0016https://github.com/DependableSystemsLab/LLFI" - } -} diff --git a/web-app/server/compileToIR.js b/web-app/server/compileToIR.js deleted file mode 100644 index f1ca6080..00000000 --- a/web-app/server/compileToIR.js +++ /dev/null @@ -1,79 +0,0 @@ -var fs = require('fs'); -var exec = require('child_process').exec; -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; -var execPromise = require('./utils/execPromise').execPromise; - -exports.processCompileToIR = function (req, res) { - - var errorStatus = false; - var fileName = req.body.fileName; - - // Extract filename without extension - fileName = fileName.replace(/\.[^/.]+$/, ""); - - // Cd to the user directory - var cdDirCmd = "cd ./uploads/" + req.ip +"/"; - - // The command to genrate make file - var generateMakeCmd = LLFI_BUILD_ROOT + "tools/GenerateMakefile --readable --all -o " + fileName + ".ll"; - - - var commands = [cdDirCmd + " && " + generateMakeCmd, cdDirCmd + " && " + "make"]; - - var consoleLog = []; - - // Execute the commands - commands.reduce(function(p, cmd) { - return p.then(function(results) { - return execPromise(cmd).then(function(stdout) { - if (stdout) results.push(stdout); - consoleLog = results; - return results; - }); - }); - }, Promise.resolve([])).then(function(results) { - // all done here, all results in the results array - - }, function(err) { - // error here - res.status(500); - res.send({error: err}); - console.log("err in compileToIR process", err); - errorStatus = true; - }).then(function() { - if (errorStatus) return; - var files = []; - // Read the compilred IR file - fs.readFile("./uploads/"+ req.ip+"/" + fileName + ".ll", 'utf8', function(err, data) { - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - var fileObj = {}; - fileObj.fileName = fileName + ".ll"; - fileObj.fileContent = data; - files.push(fileObj); - - // Load the makefile content - fs.readFile("./uploads/"+ req.ip+"/Makefile", 'utf8', function(err, data) { - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - var fileObj = {}; - fileObj.fileName = "Makefile"; - fileObj.fileContent = data; - files.push(fileObj); - var response = {files: files, consoleLog: consoleLog}; - console.log("CompileToIR success"); - res.send(response); - }); - }); - }); -}; \ No newline at end of file diff --git a/web-app/server/faultInjection.js b/web-app/server/faultInjection.js deleted file mode 100644 index 073268d5..00000000 --- a/web-app/server/faultInjection.js +++ /dev/null @@ -1,142 +0,0 @@ -var fs = require('fs'); -var readline = require('readline'); -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; -var execPromise = require('./utils/execPromise').execPromise; - -exports.processFaultInjection = function (req, res) { - - var errorStatus = false; - var fileName = req.body.fileName; - var runOptions = req.body.runOptions; - // Extract filename without extension - fileName = fileName.replace(/\.[^/.]+$/, ""); - var input = req.body.input; - var batchMode = req.body.injectionMode.isBatchMode; - - // Get the right injection command based on isBatchMode - // Todo: Batch mode is not fully supproted, the results of fault injection does not support batch Mode yet - var faultInjectionScript; - if (batchMode) { - faultInjectionScript = LLFI_BUILD_ROOT + "bin/batchInjectfault " + fileName + ".ll " + input; - } else { - faultInjectionScript = LLFI_BUILD_ROOT + "bin/injectfault " + "./llfi/" + fileName + "-faultinjection.exe " + input; - } - - var cdDirCmd = "cd ./uploads/" + req.ip +"/"; - var consoleLog = []; - var commands = []; - var faultInjectionStatus = []; - var faultSummary = {SDC: 0, Hanged: 0, Crashed: 0}; - commands.push(cdDirCmd + " && " + faultInjectionScript); - - // Execute the fault injection - commands.reduce(function(p, cmd) { - return p.then(function(results) { - return execPromise(cmd).then(function(stdout) { - results.push(stdout); - consoleLog = results; - return results; - }); - }); - }, Promise.resolve([])).then(function(results) { - if (errorStatus) return; - var statOutputDir = "./uploads/" + req.ip +"/llfi/llfi_stat_output/"; - var stdOutputDir = "./uploads/" + req.ip +"/llfi/std_output/"; - var errorDir = "./uploads/" + req.ip +"/llfi/error_output/"; - var currentRun = 0; - var goldenOutputFile = "./uploads/" + req.ip +"/llfi/baseline/golden_std_output"; - var goldenOutput = ""; - // Read golden output - fs.readFile(goldenOutputFile, 'utf8', function(err, data) { - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - goldenOutput = data; - - // Get injection stats of each run - for (var runOption = 0 ; runOption < runOptions.length; runOption++) { - for (var run = 0; run < runOptions[runOption].numOfRuns; run++) { - // Read stats from llfi.stat.fi.injectedfaults files - var injectedfaultsStatsFileName = "llfi.stat.fi.injectedfaults." + runOption + "-" + run + ".txt"; - try { - var injectedfaultsStatsData = fs.readFileSync(statOutputDir + injectedfaultsStatsFileName, 'utf8'); - } catch (err) { - // If the file is not found, close the request, return error. - // Todo: need to return the proper LLFI injection status when the [llfi.stat.fi.injectedfaults] file is not found - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - var runNumber = currentRun; - var runIndex = runOption + "-" + run; - var injectionType = getStatusValue("fi_type", injectedfaultsStatsData); - var index = getStatusValue("fi_index", injectedfaultsStatsData); - var cycle = getStatusValue("fi_cycle", injectedfaultsStatsData); - var bit = getStatusValue("fi_bit", injectedfaultsStatsData); - var status = "Injected"; - // Get SDC occurance stats - var stdOutputFileName = "std_outputfile-run-" + runOption + "-" + run; - try { - var stdOutputData = fs.readFileSync(stdOutputDir + stdOutputFileName, 'utf8'); - } catch (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - var sdc = goldenOutput === stdOutputData ? "Not Occurred" : "Occurred"; - if (sdc === "Occurred") faultSummary.SDC ++; - // Get result status - var result = "Nil"; - var errorFileName = "errorfile-run-" + runOption + "-" + run; - if (fs.existsSync(errorDir + errorFileName)) { - try { - var errorFileData = fs.readFileSync(errorDir + errorFileName, 'utf8'); - } catch (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - if (errorFileData.includes("hanged")) { - result = "Hanged"; - faultSummary.Hanged ++; - } else if (errorFileData.includes("crashed")) { - result = "Crashed"; - faultSummary.Crashed ++; - } - } - faultInjectionStatus[currentRun] = {runIndex, runNumber, injectionType, index, cycle, bit, sdc, result, status}; - currentRun ++; - } - } - var results = {faultInjectionStatus, consoleLog, faultSummary}; - res.send(results); - console.log("faultInjection success"); - }); - - }, function(err) { - // error here - if (errorStatus) return; - res.status(500); - res.send({error: err}); - console.log("err in faultInjection process", err); - errorStatus = true; - }); -}; - -// Parse the file data to get the value of a status -var getStatusValue = function (statusType, fileData) { - var keyword = statusType + "="; - var value = fileData.split(keyword)[1]; - value = value.split(",")[0]; - return value; -}; \ No newline at end of file diff --git a/web-app/server/fileUpload.js b/web-app/server/fileUpload.js deleted file mode 100644 index a79c400b..00000000 --- a/web-app/server/fileUpload.js +++ /dev/null @@ -1,72 +0,0 @@ -var path = require('path'); -var formidable = require('formidable'); -var fs = require('fs'); -var Port = 8080; -var exec = require('child_process').exec; - -exports.processFileUpload = function (req, res) { - // create an incoming form object - var form = new formidable.IncomingForm(); - - var clientIP = req.ip; - // specify that we want to allow the user to upload multiple files in a single request - form.multiples = false; - - - // The dir name for each client - var dirName = "./uploads/" + clientIP + "/"; - - exec("rm -rf " + dirName + "*", function(err, stdout) { - if (err) { - console.log("err in clearing dir", dirName, err); - res.status(500); - res.send(err); - } - }); - - // store all uploads in the /uploads directory - form.uploadDir = path.join(__dirname, dirName); - // Make a dir to store the files from a client - if (!fs.existsSync(dirName)) { - fs.mkdirSync(dirName); - } - // every time a file has been uploaded successfully, - // rename it to it's orignal name - form.on('file', function(field, file) { - this.fileName = file.name; - fs.rename(file.path, path.join(form.uploadDir, file.name), function (err) { - if (err) { - console.log("An error has occured in file rename, ", err); - } - else { - - // Send the file content back to front end - fs.readFile(dirName + file.name, 'utf8', function(err, data) { - var fileObj = {}; - fileObj.fileName = file.name; - fileObj.fileContent = data; - if (err) { - res.status(500); - res.send(err); - console.log("err in file reading, ", err); - } - res.send(fileObj); - }); - } - }); - }.bind(this)); - - // log any errors that occur - form.on('error', function(err) { - res.status(500); - res.send(err); - console.log('An error has occured: \n' + err); - }); - - // once all the files have been uploaded, send a response to the client - form.on('end', function() { - // res.end("success"); - }); - // parse the incoming request containing the form data - form.parse(req); -}; \ No newline at end of file diff --git a/web-app/server/instrument.js b/web-app/server/instrument.js deleted file mode 100644 index 78897242..00000000 --- a/web-app/server/instrument.js +++ /dev/null @@ -1,140 +0,0 @@ -var fs = require('fs'); -var readline = require('readline'); -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; -var execPromise = require('./utils/execPromise').execPromise; - -exports.processInstrument = function (req, res) { - - var errorStatus = false; - var fileName = req.body.fileName; - // Extract filename without extension - fileName = fileName.replace(/\.[^/.]+$/, ""); - - var injectionMode = req.body.injectionMode; - var injectionType = req.body.injectionType; - var traceMode = req.body.traceMode; - var backwardTrace = req.body.backwardTrace; - var forwardTrace = req.body.forwardTrace; - var maxTraceCount = req.body.maxTraceCount; - var registerLocation = req.body.registerLocation; - - // Configurations for input.yaml file - var batchMode = injectionMode == "software" && injectionType.length > 1 ? true : false; - var intrumentScript = batchMode ? LLFI_BUILD_ROOT + "bin/batchInstrument --readable " + fileName + ".ll": LLFI_BUILD_ROOT + "bin/instrument -lpthread --readable " + fileName + ".ll"; - var traceEnabled = (traceMode == "fullTrace" && (backwardTrace || forwardTrace))|| tradeMode == "limitedTrace" ? true : false; - if (traceEnabled) { - var traceDirection = []; - if (forwardTrace) traceDirection.push("forward"); - if (backwardTrace) traceDirection.push("backward"); - } - // Create a stream for input.yaml file - var stream = fs.createWriteStream("./uploads/"+ req.ip +"/input.yaml"); - - // Contents of the input.yaml file - stream.once('open', function(fd) { - stream.write("kernelOption: [forceRun]\n"); - stream.write("compileOption:\n"); - stream.write(" instSelMethod:\n"); - if(injectionMode == "software") { - stream.write(" - customInstselector:\n"); - var instrumentTypeStr = " include: ["; - instrumentTypeStr += injectionType.join(", "); - instrumentTypeStr += "]\n"; - stream.write(instrumentTypeStr); - stream.write(" regSelMethod: customregselector\n"); - stream.write(" customRegSelector: Automatic\n"); - } else if (injectionMode == "hardware") { - stream.write(" - insttype:\n"); - var instrumentTypeStr = " include: ["; - instrumentTypeStr += injectionType.join(", "); - instrumentTypeStr += "]\n"; - stream.write(instrumentTypeStr); - stream.write(" regSelMethod: regloc\n"); - stream.write(" regloc: " + registerLocation + "\n"); - } - if (traceEnabled) { - var traceDirectionStr = " includeInjectionTrace: ["; - traceDirectionStr += traceDirection.join(", "); - traceDirectionStr += "]\n"; - stream.write(traceDirectionStr); - stream.write(" tracingPropagation: true\n"); - var traceOptionStr = " tracingPropagationOption: {debugTrace: True/False, generateCDFG: true"; - if (traceMode == "limitedTrace") { - traceOptionStr += ", maxTrace: " + maxTraceCount; - } - traceOptionStr += "}\n"; - stream.write(traceOptionStr); - } - stream.end(); - }); - - var cdDirCmd = "cd ./uploads/" + req.ip +"/"; - - var softwareFailureAutoScanCmd = LLFI_BUILD_ROOT + "bin/SoftwareFailureAutoScan --no_input_yaml " + fileName + ".ll"; - - var commands = []; - - commands.push(cdDirCmd + " && " + softwareFailureAutoScanCmd); - commands.push(cdDirCmd + " && " + intrumentScript); - var consoleLog = []; - var files = []; - - // Execute the instrument step - commands.reduce(function(p, cmd) { - return p.then(function(results) { - return execPromise(cmd).then(function(stdout) { - results.push(stdout); - consoleLog = results; - return results; - }); - }); - }, Promise.resolve([])).then(function(results) { - if (errorStatus) return; - if (batchMode) { - // Copy the llfi.stat.graph.doc file - fs.createReadStream("./uploads/" + req.ip +"/llfi-" + injectionType[0]+"/llfi.stat.graph.dot").pipe(fs.createWriteStream("./uploads/" + req.ip +"/llfi.stat.graph.dot")); - var indexFilePath = "./uploads/" + req.ip +"/llfi-" + injectionType[0]+"/llfi/" + fileName + "-llfi_index.ll"; - } else { - var indexFilePath = "./uploads/" + req.ip +"/llfi/" + fileName + "-llfi_index.ll"; - } - // Generate the llfi_displayIndex.ll file - var outputIndexFilePath = "./uploads/" + req.ip +"/" + fileName + "-llfi_displayIndex.ll"; - var index = 1; - fs.readFileSync(indexFilePath).toString().split('\n').forEach(function (line) { - var modifiedLine = line; - if (line.includes("!llfi_index !")) { - modifiedLine = index + "\t\t" + line.substring(0, line.indexOf("!llfi_index !")); - index ++; - fs.appendFileSync(outputIndexFilePath, modifiedLine.toString() + "\n"); - } else if (!line.includes("= metadata !")) { - modifiedLine = "\t\t" + line; - fs.appendFileSync(outputIndexFilePath, modifiedLine.toString() + "\n"); - } - }); - - // Send the llfi_displayIndex file back to front-end - fs.readFile(outputIndexFilePath, 'utf8', function(err, data) { - if (errorStatus) return; - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - var fileObj = {}; - fileObj.fileName = fileName + "-llfi_displayIndex.ll"; - fileObj.fileContent = data; - files.push(fileObj); - var result = {files: files, consoleLog: consoleLog}; - console.log("Instrument success"); - res.send(result); - }); - }, function(err) { - // error here - if (errorStatus) return; - res.status(500); - res.send({error: err}); - console.log("err in Instrument process", err); - errorStatus = true; - }); -}; \ No newline at end of file diff --git a/web-app/server/preInstrument.js b/web-app/server/preInstrument.js deleted file mode 100644 index 904ec22d..00000000 --- a/web-app/server/preInstrument.js +++ /dev/null @@ -1,50 +0,0 @@ -var fs = require('fs'); -var readline = require('readline'); -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; -var execPromise = require('./utils/execPromise').execPromise; - -// Do a hardware and software auto scan, send the applicable injection types back to the client -exports.processPreInstrument = function (req, res) { - var errorStatus = false; - var fileName = req.body.fileName; - - // Extract filename without extension - fileName = fileName.replace(/\.[^/.]+$/, ""); - - var cdDirCmd = "cd ./uploads/" + req.ip +"/"; - - var softwareFailureAutoScanCmd = LLFI_BUILD_ROOT + "bin/SoftwareFailureAutoScan --no_input_yaml " + fileName + ".ll"; - var commands = []; - - commands.push(cdDirCmd + " && " + softwareFailureAutoScanCmd); - - var softwareInjectionTypes = []; - - // Execute the auto scan script - commands.reduce(function(p, cmd) { - return p.then(function(results) { - return execPromise(cmd).then(function(stdout) { - results.push(stdout); - return results; - }); - }); - }, Promise.resolve([])).then(function(results) { - // Read software injection types - var softwareFile = "./uploads/" + req.ip +"/llfi.applicable.software.failures.txt"; - fs.readFileSync(softwareFile).toString().split('\n').forEach(function (line) { - if (line.includes("- ")) { - var injectionType = line.substring(line.indexOf("- ")+ 2); - softwareInjectionTypes.push(injectionType); - } - }); - }, function(err) { - // error here - res.status(500); - res.send({error: err}); - console.log("err in preInstrument process", err); - errorStatus = true; - }).then(function(){ - if (errorStatus) return; - res.send(softwareInjectionTypes); - }); -}; \ No newline at end of file diff --git a/web-app/server/profiling.js b/web-app/server/profiling.js deleted file mode 100644 index ba09ec52..00000000 --- a/web-app/server/profiling.js +++ /dev/null @@ -1,77 +0,0 @@ -var fs = require('fs'); -var readline = require('readline'); -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; -var execPromise = require('./utils/execPromise').execPromise; - -exports.processProfiling = function (req, res) { - - var errorStatus = false; - var fileName = req.body.fileName; - // Extract filename without extension - fileName = fileName.replace(/\.[^/.]+$/, ""); - var input = req.body.input; - var injectionMode = req.body.injectionMode.injectionMode; - var injectionType = req.body.injectionMode.injectionType; - var profilingType = injectionMode == "hardware" ? "Hardware Fault(s)" : injectionType[0]; - var batchMode = req.body.injectionMode.isBatchMode; - var profilingScript; - if (batchMode) { - profilingScript = LLFI_BUILD_ROOT + "bin/batchProfile " + fileName + ".ll " + input; - } else { - profilingScript = LLFI_BUILD_ROOT + "bin/profile " + "./llfi/" + fileName + "-profiling.exe " + input; - } - - var cdDirCmd = "cd ./uploads/" + req.ip +"/"; - var consoleLog = []; - var commands = []; - commands.push(cdDirCmd + " && " + profilingScript); - - // Execute the profiling step - commands.reduce(function(p, cmd) { - return p.then(function(results) { - return execPromise(cmd).then(function(stdout) { - results.push(stdout); - consoleLog = results; - return results; - }); - }); - }, Promise.resolve([])).then(function(results) { - if (errorStatus) return; - // Get the profiling total index status - var totalIndexFilePath = "./uploads/" + req.ip +"/" + "llfi.stat.totalindex.txt"; - fs.readFile(totalIndexFilePath, 'utf8', function(err, data) { - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - var totalIndex = parseInt(data.split("=")[1]); - // Get the lastCycle status - var profilingStatsFilePath = "./uploads/" + req.ip +"/" + "llfi.stat.prof.txt"; - fs.readFile(profilingStatsFilePath, 'utf8', function(err, data) { - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - var lastCycle = parseInt(data.split("=")[1]); - lastCycle = lastCycle == 0 ? 0 : lastCycle -1 ; - var profilingStats = [{type: profilingType, lastIndex: totalIndex, lastCycle: lastCycle}]; - console.log("Profiling success"); - var result = {profilingStats: profilingStats, consoleLog: consoleLog}; - res.send(result); - }); - }); - }, function(err) { - // error here - if (errorStatus) return; - res.status(500); - res.send({error: err}); - console.log("err in Profiling process", err); - errorStatus = true; - }); -}; \ No newline at end of file diff --git a/web-app/server/runtimeOptions.js b/web-app/server/runtimeOptions.js deleted file mode 100644 index 221eee98..00000000 --- a/web-app/server/runtimeOptions.js +++ /dev/null @@ -1,38 +0,0 @@ -var fs = require('fs'); -var readline = require('readline'); -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; - -exports.processRuntimeOptions = function (req, res) { - - var runtimeOptions = req.body.runtimeOptions; - var inputYamlFilePath = "./uploads/"+ req.ip +"/input.yaml"; - var data = ""; - if (runtimeOptions.length) data += "runOption:\n"; - for (var j = 0; j < runtimeOptions.length; j ++) { - var runOption = runtimeOptions[j]; - data += "- run: {"; - - for(var keys = Object.keys(runOption), i = 0, end = keys.length - 1; i < end; i++) { - var key = keys[i], value = runOption[key]; - data += key + ": " + value + ", "; - } - var lastIndex = Object.keys(runOption).length - 1; - if(lastIndex >= 0) { - var lastKey = Object.keys(runOption)[lastIndex]; - var value = runOption[lastKey]; - data += lastKey + ": " + value + "}\n"; - } - } - - // Append the status to input yaml file - fs.appendFile(inputYamlFilePath, data, function (err) { - if (err) { - res.status(500); - res.send(err); - console.log("err in modifying input.yaml file in runtimeOption: ", err); - } else { - console.log("runtimeOption Submit success"); - res.end(); - } - }); -}; diff --git a/web-app/server/server.js b/web-app/server/server.js deleted file mode 100644 index 04dc0121..00000000 --- a/web-app/server/server.js +++ /dev/null @@ -1,115 +0,0 @@ -var express = require('express'); -var app = express(); -var path = require('path'); -var formidable = require('formidable'); -var fs = require('fs'); -var http = require('http'); - -var Port = 8080; -var fileUpload = require('./fileUpload'); -var compileToIR = require('./compileToIR'); -var preInstrument = require('./preInstrument'); -var instrument = require('./instrument'); -var profiling = require('./profiling'); -var runtimeOptions = require('./runtimeOptions'); -var faultInjection = require('./faultInjection'); -var traceGraph = require('./traceGraph'); -var bodyParser = require('body-parser'); - -app.use(express.static(path.join(__dirname, '../views'))); -app.use(bodyParser.json()); - -app.get('/', function(req, res){ - try { - res.sendFile(path.join(__dirname, 'index.html')); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/uploadFile', function(req, res){ - try { - fileUpload.processFileUpload(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/compileToIR', function(req, res){ - try { - compileToIR.processCompileToIR(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/preInstrument', function(req, res){ - try { - preInstrument.processPreInstrument(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/instrument', function(req, res){ - try { - instrument.processInstrument(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/profiling', function(req, res){ - try { - profiling.processProfiling(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/runtimeOptions', function(req, res){ - try { - runtimeOptions.processRuntimeOptions(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/faultInjection', function(req, res){ - try { - faultInjection.processFaultInjection(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.post('/traceGraph', function(req, res){ - try { - traceGraph.processTrace(req,res); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -app.get('/tracepdf', function(req, res){ - try { - res.download("./uploads/" + req.ip +"/llfi/trace_report_output/TraceGraph.pdf"); - console.log('Trace graph sent'); - } catch (err) { - res.status(500); - res.send(err); - } -}); - -var server = app.listen(Port, function(){ - console.log('Server listening on port' + Port); -}); diff --git a/web-app/server/traceGraph.js b/web-app/server/traceGraph.js deleted file mode 100644 index 172faac0..00000000 --- a/web-app/server/traceGraph.js +++ /dev/null @@ -1,114 +0,0 @@ -var fs = require('fs'); -var readline = require('readline'); -var LLFI_BUILD_ROOT = require('./utils/config').LLFI_BUILD_ROOT; -var execPromise = require('./utils/execPromise').execPromise; - -exports.processTrace = function (req, res) { - var errorStatus = false; - - var traceRunIndex = req.body.selectedRunIndex; - - // If no trace is selected, end the resquest - if (traceRunIndex.length <= 0) { - res.status(500); - res.send({error: "No trace is selected"}); - return; - } - - var traceFolder = "./uploads/" + req.ip +"/llfi/trace_report_output/"; - // Make a dir to store the files from a client - if (!fs.existsSync(traceFolder)) { - fs.mkdirSync(traceFolder); - } - - var goldenFile = "./llfi/baseline/llfi.stat.trace.prof.txt"; - var llfi_stat_output = "./uploads/" + req.ip +"/llfi/llfi_stat_output/"; - var runNumbers = []; - var selectedTraceFileNames = []; - var traceDiffFileNames = []; - var commands = []; - var consoleLog = []; - var cdDirCmd = "cd ./uploads/" + req.ip +"/"; - // Get the number of runs in each run option - fs.readdir(llfi_stat_output, (err, files) => { - if (err) { - res.status(500); - res.send(err); - errorStatus = true; - console.log("err in file reading, ", err); - } - if (errorStatus) return; - // Get the selected Trace file names - var currentRunOptionNumber = 0; - var currentRunNumberOffset = 0; - for (var i = 0; i < traceRunIndex.length; i++) { - var traceFileName = "llfi.stat.trace." + traceRunIndex[i] + ".txt"; - if (files.indexOf(traceFileName) > -1) { - selectedTraceFileNames.push("llfi.stat.trace." + traceRunIndex[i] + ".txt"); - } - } - - // TraceDiff commands - for (var i = 0; i < selectedTraceFileNames.length; i++) { - if (files.indexOf(selectedTraceFileNames[i]) > -1) { - var nameParser = selectedTraceFileNames[i].split("llfi.stat.trace.")[1]; - var runOption = parseInt(nameParser.split("-")[0]); - var runNumber = nameParser.split("-")[1]; - runNumber = parseInt(runNumber.split(".txt")[0]); - var tradeDiffFileName = "TraceDiffReportFile." + runOption + "-" + runNumber + ".txt"; - traceDiffFileNames.push(tradeDiffFileName); - var tradeFile = "./llfi/llfi_stat_output/" + selectedTraceFileNames[i]; - var command = LLFI_BUILD_ROOT + "tools/tracediff " + goldenFile + " " + tradeFile + " > " + "./llfi/trace_report_output/" + tradeDiffFileName; - - commands.push(cdDirCmd + " && " + command); - } - } - - //Trace Union command - var traceUnionCmd = LLFI_BUILD_ROOT + "tools/traceunion "; - for (var i = 0; i < traceDiffFileNames.length; i++) { - traceUnionCmd += "./llfi/trace_report_output/" + traceDiffFileNames[i] + " "; - } - traceUnionCmd += "> ./llfi/trace_report_output/UnionedDiffReportFile.txt"; - commands.push(cdDirCmd + " && " + traceUnionCmd); - - // traceontograph command - var tracetoGraphCmd = LLFI_BUILD_ROOT + "tools/traceontograph ./llfi/trace_report_output/UnionedDiffReportFile.txt ./llfi.stat.graph.dot > ./llfi/trace_report_output/TraceGraph.dot"; - commands.push(cdDirCmd + " && " + tracetoGraphCmd); - - //Covert dot file to pdf file - var traceCovertCmd = "dot -Tpdf ./llfi/trace_report_output/TraceGraph.dot -o ./llfi/trace_report_output/TraceGraph.pdf"; - commands.push(cdDirCmd + " && " + traceCovertCmd); - - //Execute the commands - commands.reduce(function(p, cmd) { - return p.then(function(results) { - return execPromise(cmd).then(function(stdout) { - results.push(stdout); - consoleLog = results; - return results; - }, function(err) { - console.log("Trace onto graph err: ", err); - }); - }); - }, Promise.resolve([])).then(function(results) { - if (errorStatus) return; - res.send({consoleLog: consoleLog}); - }, function(err) { - // error here - if (errorStatus) return; - res.status(500); - res.send({error: err}); - console.log("err in traceGraph process", err); - errorStatus = true; - }); - }); -}; - -// Parse the file data to get the value of a status -var getStatusValue = function (statusType, fileData) { - var keyword = statusType + "="; - var value = fileData.split(keyword)[1]; - value = value.split(",")[0]; - return value; -}; \ No newline at end of file diff --git a/web-app/server/utils/config.js b/web-app/server/utils/config.js deleted file mode 100644 index 8775de27..00000000 --- a/web-app/server/utils/config.js +++ /dev/null @@ -1 +0,0 @@ -exports.LLFI_BUILD_ROOT = "$llfibuild/"; \ No newline at end of file diff --git a/web-app/server/utils/execPromise.js b/web-app/server/utils/execPromise.js deleted file mode 100644 index 11dc8030..00000000 --- a/web-app/server/utils/execPromise.js +++ /dev/null @@ -1,10 +0,0 @@ -var exec = require('child_process').exec; - -exports.execPromise = function(cmd) { - return new Promise(function(resolve, reject) { - exec(cmd, function(err, stdout) { - if (err) return reject(err); - resolve(cmd + stdout); - }); - }); -}; \ No newline at end of file diff --git a/web-app/views/bundle.min.js b/web-app/views/bundle.min.js deleted file mode 100644 index e421751f..00000000 --- a/web-app/views/bundle.min.js +++ /dev/null @@ -1,42915 +0,0 @@ -/******/ (function(modules) { // webpackBootstrap -/******/ // The module cache -/******/ var installedModules = {}; -/******/ -/******/ // The require function -/******/ function __webpack_require__(moduleId) { -/******/ -/******/ // Check if module is in cache -/******/ if(installedModules[moduleId]) -/******/ return installedModules[moduleId].exports; -/******/ -/******/ // Create a new module (and put it into the cache) -/******/ var module = installedModules[moduleId] = { -/******/ exports: {}, -/******/ id: moduleId, -/******/ loaded: false -/******/ }; -/******/ -/******/ // Execute the module function -/******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); -/******/ -/******/ // Flag the module as loaded -/******/ module.loaded = true; -/******/ -/******/ // Return the exports of the module -/******/ return module.exports; -/******/ } -/******/ -/******/ -/******/ // expose the modules object (__webpack_modules__) -/******/ __webpack_require__.m = modules; -/******/ -/******/ // expose the module cache -/******/ __webpack_require__.c = installedModules; -/******/ -/******/ // __webpack_public_path__ -/******/ __webpack_require__.p = ""; -/******/ -/******/ // Load entry module and return exports -/******/ return __webpack_require__(0); -/******/ }) -/************************************************************************/ -/******/ ([ -/* 0 */ -/***/ function(module, exports, __webpack_require__) { - - "use strict"; - - var _react = __webpack_require__(1); - - var _react2 = _interopRequireDefault(_react); - - var _reactDom = __webpack_require__(158); - - var _reactDom2 = _interopRequireDefault(_reactDom); - - var _layout = __webpack_require__(159); - - var _layout2 = _interopRequireDefault(_layout); - - function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } - - var app = document.getElementById('app'); - _reactDom2.default.render(_react2.default.createElement(_layout2.default, null), app); - -/***/ }, -/* 1 */ -/***/ function(module, exports, __webpack_require__) { - - 'use strict'; - - module.exports = __webpack_require__(2); - - -/***/ }, -/* 2 */ -/***/ function(module, exports, __webpack_require__) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule React - */ - - 'use strict'; - - var ReactDOM = __webpack_require__(3); - var ReactDOMServer = __webpack_require__(148); - var ReactIsomorphic = __webpack_require__(152); - - var assign = __webpack_require__(39); - var deprecated = __webpack_require__(157); - - // `version` will be added here by ReactIsomorphic. - var React = {}; - - assign(React, ReactIsomorphic); - - assign(React, { - // ReactDOM - findDOMNode: deprecated('findDOMNode', 'ReactDOM', 'react-dom', ReactDOM, ReactDOM.findDOMNode), - render: deprecated('render', 'ReactDOM', 'react-dom', ReactDOM, ReactDOM.render), - unmountComponentAtNode: deprecated('unmountComponentAtNode', 'ReactDOM', 'react-dom', ReactDOM, ReactDOM.unmountComponentAtNode), - - // ReactDOMServer - renderToString: deprecated('renderToString', 'ReactDOMServer', 'react-dom/server', ReactDOMServer, ReactDOMServer.renderToString), - renderToStaticMarkup: deprecated('renderToStaticMarkup', 'ReactDOMServer', 'react-dom/server', ReactDOMServer, ReactDOMServer.renderToStaticMarkup) - }); - - React.__SECRET_DOM_DO_NOT_USE_OR_YOU_WILL_BE_FIRED = ReactDOM; - React.__SECRET_DOM_SERVER_DO_NOT_USE_OR_YOU_WILL_BE_FIRED = ReactDOMServer; - - module.exports = React; - -/***/ }, -/* 3 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactDOM - */ - - /* globals __REACT_DEVTOOLS_GLOBAL_HOOK__*/ - - 'use strict'; - - var ReactCurrentOwner = __webpack_require__(5); - var ReactDOMTextComponent = __webpack_require__(6); - var ReactDefaultInjection = __webpack_require__(71); - var ReactInstanceHandles = __webpack_require__(45); - var ReactMount = __webpack_require__(28); - var ReactPerf = __webpack_require__(18); - var ReactReconciler = __webpack_require__(50); - var ReactUpdates = __webpack_require__(54); - var ReactVersion = __webpack_require__(146); - - var findDOMNode = __webpack_require__(91); - var renderSubtreeIntoContainer = __webpack_require__(147); - var warning = __webpack_require__(25); - - ReactDefaultInjection.inject(); - - var render = ReactPerf.measure('React', 'render', ReactMount.render); - - var React = { - findDOMNode: findDOMNode, - render: render, - unmountComponentAtNode: ReactMount.unmountComponentAtNode, - version: ReactVersion, - - /* eslint-disable camelcase */ - unstable_batchedUpdates: ReactUpdates.batchedUpdates, - unstable_renderSubtreeIntoContainer: renderSubtreeIntoContainer - }; - - // Inject the runtime into a devtools global hook regardless of browser. - // Allows for debugging when the hook is injected on the page. - /* eslint-enable camelcase */ - if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== 'undefined' && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.inject === 'function') { - __REACT_DEVTOOLS_GLOBAL_HOOK__.inject({ - CurrentOwner: ReactCurrentOwner, - InstanceHandles: ReactInstanceHandles, - Mount: ReactMount, - Reconciler: ReactReconciler, - TextComponent: ReactDOMTextComponent - }); - } - - if (process.env.NODE_ENV !== 'production') { - var ExecutionEnvironment = __webpack_require__(9); - if (ExecutionEnvironment.canUseDOM && window.top === window.self) { - - // First check if devtools is not installed - if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ === 'undefined') { - // If we're in Chrome or Firefox, provide a download link if not installed. - if (navigator.userAgent.indexOf('Chrome') > -1 && navigator.userAgent.indexOf('Edge') === -1 || navigator.userAgent.indexOf('Firefox') > -1) { - console.debug('Download the React DevTools for a better development experience: ' + 'https://fb.me/react-devtools'); - } - } - - // If we're in IE8, check to see if we are in compatibility mode and provide - // information on preventing compatibility mode - var ieCompatibilityMode = document.documentMode && document.documentMode < 8; - - process.env.NODE_ENV !== 'production' ? warning(!ieCompatibilityMode, 'Internet Explorer is running in compatibility mode; please add the ' + 'following tag to your HTML to prevent this from happening: ' + '') : undefined; - - var expectedFeatures = [ - // shims - Array.isArray, Array.prototype.every, Array.prototype.forEach, Array.prototype.indexOf, Array.prototype.map, Date.now, Function.prototype.bind, Object.keys, String.prototype.split, String.prototype.trim, - - // shams - Object.create, Object.freeze]; - - for (var i = 0; i < expectedFeatures.length; i++) { - if (!expectedFeatures[i]) { - console.error('One or more ES5 shim/shams expected by React are not available: ' + 'https://fb.me/react-warning-polyfills'); - break; - } - } - } - } - - module.exports = React; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 4 */ -/***/ function(module, exports) { - - // shim for using process in browser - var process = module.exports = {}; - - // cached from whatever global is present so that test runners that stub it - // don't break things. But we need to wrap it in a try catch in case it is - // wrapped in strict mode code which doesn't define any globals. It's inside a - // function because try/catches deoptimize in certain engines. - - var cachedSetTimeout; - var cachedClearTimeout; - - function defaultSetTimout() { - throw new Error('setTimeout has not been defined'); - } - function defaultClearTimeout () { - throw new Error('clearTimeout has not been defined'); - } - (function () { - try { - if (typeof setTimeout === 'function') { - cachedSetTimeout = setTimeout; - } else { - cachedSetTimeout = defaultSetTimout; - } - } catch (e) { - cachedSetTimeout = defaultSetTimout; - } - try { - if (typeof clearTimeout === 'function') { - cachedClearTimeout = clearTimeout; - } else { - cachedClearTimeout = defaultClearTimeout; - } - } catch (e) { - cachedClearTimeout = defaultClearTimeout; - } - } ()) - function runTimeout(fun) { - if (cachedSetTimeout === setTimeout) { - //normal enviroments in sane situations - return setTimeout(fun, 0); - } - // if setTimeout wasn't available but was latter defined - if ((cachedSetTimeout === defaultSetTimout || !cachedSetTimeout) && setTimeout) { - cachedSetTimeout = setTimeout; - return setTimeout(fun, 0); - } - try { - // when when somebody has screwed with setTimeout but no I.E. maddness - return cachedSetTimeout(fun, 0); - } catch(e){ - try { - // When we are in I.E. but the script has been evaled so I.E. doesn't trust the global object when called normally - return cachedSetTimeout.call(null, fun, 0); - } catch(e){ - // same as above but when it's a version of I.E. that must have the global object for 'this', hopfully our context correct otherwise it will throw a global error - return cachedSetTimeout.call(this, fun, 0); - } - } - - - } - function runClearTimeout(marker) { - if (cachedClearTimeout === clearTimeout) { - //normal enviroments in sane situations - return clearTimeout(marker); - } - // if clearTimeout wasn't available but was latter defined - if ((cachedClearTimeout === defaultClearTimeout || !cachedClearTimeout) && clearTimeout) { - cachedClearTimeout = clearTimeout; - return clearTimeout(marker); - } - try { - // when when somebody has screwed with setTimeout but no I.E. maddness - return cachedClearTimeout(marker); - } catch (e){ - try { - // When we are in I.E. but the script has been evaled so I.E. doesn't trust the global object when called normally - return cachedClearTimeout.call(null, marker); - } catch (e){ - // same as above but when it's a version of I.E. that must have the global object for 'this', hopfully our context correct otherwise it will throw a global error. - // Some versions of I.E. have different rules for clearTimeout vs setTimeout - return cachedClearTimeout.call(this, marker); - } - } - - - - } - var queue = []; - var draining = false; - var currentQueue; - var queueIndex = -1; - - function cleanUpNextTick() { - if (!draining || !currentQueue) { - return; - } - draining = false; - if (currentQueue.length) { - queue = currentQueue.concat(queue); - } else { - queueIndex = -1; - } - if (queue.length) { - drainQueue(); - } - } - - function drainQueue() { - if (draining) { - return; - } - var timeout = runTimeout(cleanUpNextTick); - draining = true; - - var len = queue.length; - while(len) { - currentQueue = queue; - queue = []; - while (++queueIndex < len) { - if (currentQueue) { - currentQueue[queueIndex].run(); - } - } - queueIndex = -1; - len = queue.length; - } - currentQueue = null; - draining = false; - runClearTimeout(timeout); - } - - process.nextTick = function (fun) { - var args = new Array(arguments.length - 1); - if (arguments.length > 1) { - for (var i = 1; i < arguments.length; i++) { - args[i - 1] = arguments[i]; - } - } - queue.push(new Item(fun, args)); - if (queue.length === 1 && !draining) { - runTimeout(drainQueue); - } - }; - - // v8 likes predictible objects - function Item(fun, array) { - this.fun = fun; - this.array = array; - } - Item.prototype.run = function () { - this.fun.apply(null, this.array); - }; - process.title = 'browser'; - process.browser = true; - process.env = {}; - process.argv = []; - process.version = ''; // empty string to avoid regexp issues - process.versions = {}; - - function noop() {} - - process.on = noop; - process.addListener = noop; - process.once = noop; - process.off = noop; - process.removeListener = noop; - process.removeAllListeners = noop; - process.emit = noop; - - process.binding = function (name) { - throw new Error('process.binding is not supported'); - }; - - process.cwd = function () { return '/' }; - process.chdir = function (dir) { - throw new Error('process.chdir is not supported'); - }; - process.umask = function() { return 0; }; - - -/***/ }, -/* 5 */ -/***/ function(module, exports) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactCurrentOwner - */ - - 'use strict'; - - /** - * Keeps track of the current owner. - * - * The current owner is the component who should own any components that are - * currently being constructed. - */ - var ReactCurrentOwner = { - - /** - * @internal - * @type {ReactComponent} - */ - current: null - - }; - - module.exports = ReactCurrentOwner; - -/***/ }, -/* 6 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactDOMTextComponent - * @typechecks static-only - */ - - 'use strict'; - - var DOMChildrenOperations = __webpack_require__(7); - var DOMPropertyOperations = __webpack_require__(22); - var ReactComponentBrowserEnvironment = __webpack_require__(26); - var ReactMount = __webpack_require__(28); - - var assign = __webpack_require__(39); - var escapeTextContentForBrowser = __webpack_require__(21); - var setTextContent = __webpack_require__(20); - var validateDOMNesting = __webpack_require__(70); - - /** - * Text nodes violate a couple assumptions that React makes about components: - * - * - When mounting text into the DOM, adjacent text nodes are merged. - * - Text nodes cannot be assigned a React root ID. - * - * This component is used to wrap strings in elements so that they can undergo - * the same reconciliation that is applied to elements. - * - * TODO: Investigate representing React components in the DOM with text nodes. - * - * @class ReactDOMTextComponent - * @extends ReactComponent - * @internal - */ - var ReactDOMTextComponent = function (props) { - // This constructor and its argument is currently used by mocks. - }; - - assign(ReactDOMTextComponent.prototype, { - - /** - * @param {ReactText} text - * @internal - */ - construct: function (text) { - // TODO: This is really a ReactText (ReactNode), not a ReactElement - this._currentElement = text; - this._stringText = '' + text; - - // Properties - this._rootNodeID = null; - this._mountIndex = 0; - }, - - /** - * Creates the markup for this text node. This node is not intended to have - * any features besides containing text content. - * - * @param {string} rootID DOM ID of the root node. - * @param {ReactReconcileTransaction|ReactServerRenderingTransaction} transaction - * @return {string} Markup for this text node. - * @internal - */ - mountComponent: function (rootID, transaction, context) { - if (process.env.NODE_ENV !== 'production') { - if (context[validateDOMNesting.ancestorInfoContextKey]) { - validateDOMNesting('span', null, context[validateDOMNesting.ancestorInfoContextKey]); - } - } - - this._rootNodeID = rootID; - if (transaction.useCreateElement) { - var ownerDocument = context[ReactMount.ownerDocumentContextKey]; - var el = ownerDocument.createElement('span'); - DOMPropertyOperations.setAttributeForID(el, rootID); - // Populate node cache - ReactMount.getID(el); - setTextContent(el, this._stringText); - return el; - } else { - var escapedText = escapeTextContentForBrowser(this._stringText); - - if (transaction.renderToStaticMarkup) { - // Normally we'd wrap this in a `span` for the reasons stated above, but - // since this is a situation where React won't take over (static pages), - // we can simply return the text as it is. - return escapedText; - } - - return '' + escapedText + ''; - } - }, - - /** - * Updates this component by updating the text content. - * - * @param {ReactText} nextText The next text content - * @param {ReactReconcileTransaction} transaction - * @internal - */ - receiveComponent: function (nextText, transaction) { - if (nextText !== this._currentElement) { - this._currentElement = nextText; - var nextStringText = '' + nextText; - if (nextStringText !== this._stringText) { - // TODO: Save this as pending props and use performUpdateIfNecessary - // and/or updateComponent to do the actual update for consistency with - // other component types? - this._stringText = nextStringText; - var node = ReactMount.getNode(this._rootNodeID); - DOMChildrenOperations.updateTextContent(node, nextStringText); - } - } - }, - - unmountComponent: function () { - ReactComponentBrowserEnvironment.unmountIDFromEnvironment(this._rootNodeID); - } - - }); - - module.exports = ReactDOMTextComponent; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 7 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule DOMChildrenOperations - * @typechecks static-only - */ - - 'use strict'; - - var Danger = __webpack_require__(8); - var ReactMultiChildUpdateTypes = __webpack_require__(16); - var ReactPerf = __webpack_require__(18); - - var setInnerHTML = __webpack_require__(19); - var setTextContent = __webpack_require__(20); - var invariant = __webpack_require__(13); - - /** - * Inserts `childNode` as a child of `parentNode` at the `index`. - * - * @param {DOMElement} parentNode Parent node in which to insert. - * @param {DOMElement} childNode Child node to insert. - * @param {number} index Index at which to insert the child. - * @internal - */ - function insertChildAt(parentNode, childNode, index) { - // By exploiting arrays returning `undefined` for an undefined index, we can - // rely exclusively on `insertBefore(node, null)` instead of also using - // `appendChild(node)`. However, using `undefined` is not allowed by all - // browsers so we must replace it with `null`. - - // fix render order error in safari - // IE8 will throw error when index out of list size. - var beforeChild = index >= parentNode.childNodes.length ? null : parentNode.childNodes.item(index); - - parentNode.insertBefore(childNode, beforeChild); - } - - /** - * Operations for updating with DOM children. - */ - var DOMChildrenOperations = { - - dangerouslyReplaceNodeWithMarkup: Danger.dangerouslyReplaceNodeWithMarkup, - - updateTextContent: setTextContent, - - /** - * Updates a component's children by processing a series of updates. The - * update configurations are each expected to have a `parentNode` property. - * - * @param {array} updates List of update configurations. - * @param {array} markupList List of markup strings. - * @internal - */ - processUpdates: function (updates, markupList) { - var update; - // Mapping from parent IDs to initial child orderings. - var initialChildren = null; - // List of children that will be moved or removed. - var updatedChildren = null; - - for (var i = 0; i < updates.length; i++) { - update = updates[i]; - if (update.type === ReactMultiChildUpdateTypes.MOVE_EXISTING || update.type === ReactMultiChildUpdateTypes.REMOVE_NODE) { - var updatedIndex = update.fromIndex; - var updatedChild = update.parentNode.childNodes[updatedIndex]; - var parentID = update.parentID; - - !updatedChild ? process.env.NODE_ENV !== 'production' ? invariant(false, 'processUpdates(): Unable to find child %s of element. This ' + 'probably means the DOM was unexpectedly mutated (e.g., by the ' + 'browser), usually due to forgetting a when using tables, ' + 'nesting tags like
,

, or , or using non-SVG elements ' + 'in an parent. Try inspecting the child nodes of the element ' + 'with React ID `%s`.', updatedIndex, parentID) : invariant(false) : undefined; - - initialChildren = initialChildren || {}; - initialChildren[parentID] = initialChildren[parentID] || []; - initialChildren[parentID][updatedIndex] = updatedChild; - - updatedChildren = updatedChildren || []; - updatedChildren.push(updatedChild); - } - } - - var renderedMarkup; - // markupList is either a list of markup or just a list of elements - if (markupList.length && typeof markupList[0] === 'string') { - renderedMarkup = Danger.dangerouslyRenderMarkup(markupList); - } else { - renderedMarkup = markupList; - } - - // Remove updated children first so that `toIndex` is consistent. - if (updatedChildren) { - for (var j = 0; j < updatedChildren.length; j++) { - updatedChildren[j].parentNode.removeChild(updatedChildren[j]); - } - } - - for (var k = 0; k < updates.length; k++) { - update = updates[k]; - switch (update.type) { - case ReactMultiChildUpdateTypes.INSERT_MARKUP: - insertChildAt(update.parentNode, renderedMarkup[update.markupIndex], update.toIndex); - break; - case ReactMultiChildUpdateTypes.MOVE_EXISTING: - insertChildAt(update.parentNode, initialChildren[update.parentID][update.fromIndex], update.toIndex); - break; - case ReactMultiChildUpdateTypes.SET_MARKUP: - setInnerHTML(update.parentNode, update.content); - break; - case ReactMultiChildUpdateTypes.TEXT_CONTENT: - setTextContent(update.parentNode, update.content); - break; - case ReactMultiChildUpdateTypes.REMOVE_NODE: - // Already removed by the for-loop above. - break; - } - } - } - - }; - - ReactPerf.measureMethods(DOMChildrenOperations, 'DOMChildrenOperations', { - updateTextContent: 'updateTextContent' - }); - - module.exports = DOMChildrenOperations; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 8 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule Danger - * @typechecks static-only - */ - - 'use strict'; - - var ExecutionEnvironment = __webpack_require__(9); - - var createNodesFromMarkup = __webpack_require__(10); - var emptyFunction = __webpack_require__(15); - var getMarkupWrap = __webpack_require__(14); - var invariant = __webpack_require__(13); - - var OPEN_TAG_NAME_EXP = /^(<[^ \/>]+)/; - var RESULT_INDEX_ATTR = 'data-danger-index'; - - /** - * Extracts the `nodeName` from a string of markup. - * - * NOTE: Extracting the `nodeName` does not require a regular expression match - * because we make assumptions about React-generated markup (i.e. there are no - * spaces surrounding the opening tag and there is at least one attribute). - * - * @param {string} markup String of markup. - * @return {string} Node name of the supplied markup. - * @see http://jsperf.com/extract-nodename - */ - function getNodeName(markup) { - return markup.substring(1, markup.indexOf(' ')); - } - - var Danger = { - - /** - * Renders markup into an array of nodes. The markup is expected to render - * into a list of root nodes. Also, the length of `resultList` and - * `markupList` should be the same. - * - * @param {array} markupList List of markup strings to render. - * @return {array} List of rendered nodes. - * @internal - */ - dangerouslyRenderMarkup: function (markupList) { - !ExecutionEnvironment.canUseDOM ? process.env.NODE_ENV !== 'production' ? invariant(false, 'dangerouslyRenderMarkup(...): Cannot render markup in a worker ' + 'thread. Make sure `window` and `document` are available globally ' + 'before requiring React when unit testing or use ' + 'ReactDOMServer.renderToString for server rendering.') : invariant(false) : undefined; - var nodeName; - var markupByNodeName = {}; - // Group markup by `nodeName` if a wrap is necessary, else by '*'. - for (var i = 0; i < markupList.length; i++) { - !markupList[i] ? process.env.NODE_ENV !== 'production' ? invariant(false, 'dangerouslyRenderMarkup(...): Missing markup.') : invariant(false) : undefined; - nodeName = getNodeName(markupList[i]); - nodeName = getMarkupWrap(nodeName) ? nodeName : '*'; - markupByNodeName[nodeName] = markupByNodeName[nodeName] || []; - markupByNodeName[nodeName][i] = markupList[i]; - } - var resultList = []; - var resultListAssignmentCount = 0; - for (nodeName in markupByNodeName) { - if (!markupByNodeName.hasOwnProperty(nodeName)) { - continue; - } - var markupListByNodeName = markupByNodeName[nodeName]; - - // This for-in loop skips the holes of the sparse array. The order of - // iteration should follow the order of assignment, which happens to match - // numerical index order, but we don't rely on that. - var resultIndex; - for (resultIndex in markupListByNodeName) { - if (markupListByNodeName.hasOwnProperty(resultIndex)) { - var markup = markupListByNodeName[resultIndex]; - - // Push the requested markup with an additional RESULT_INDEX_ATTR - // attribute. If the markup does not start with a < character, it - // will be discarded below (with an appropriate console.error). - markupListByNodeName[resultIndex] = markup.replace(OPEN_TAG_NAME_EXP, - // This index will be parsed back out below. - '$1 ' + RESULT_INDEX_ATTR + '="' + resultIndex + '" '); - } - } - - // Render each group of markup with similar wrapping `nodeName`. - var renderNodes = createNodesFromMarkup(markupListByNodeName.join(''), emptyFunction // Do nothing special with ', '
']; - var trWrap = [3, '', '
']; - - var svgWrap = [1, '', '']; - - var markupWrap = { - '*': [1, '?

'], - - 'area': [1, '', ''], - 'col': [2, '', '
'], - 'legend': [1, '
', '
'], - 'param': [1, '', ''], - 'tr': [2, '', '
'], - - 'optgroup': selectWrap, - 'option': selectWrap, - - 'caption': tableWrap, - 'colgroup': tableWrap, - 'tbody': tableWrap, - 'tfoot': tableWrap, - 'thead': tableWrap, - - 'td': trWrap, - 'th': trWrap - }; - - // Initialize the SVG elements since we know they'll always need to be wrapped - // consistently. If they are created inside a
they will be initialized in - // the wrong namespace (and will not display). - var svgElements = ['circle', 'clipPath', 'defs', 'ellipse', 'g', 'image', 'line', 'linearGradient', 'mask', 'path', 'pattern', 'polygon', 'polyline', 'radialGradient', 'rect', 'stop', 'text', 'tspan']; - svgElements.forEach(function (nodeName) { - markupWrap[nodeName] = svgWrap; - shouldWrap[nodeName] = true; - }); - - /** - * Gets the markup wrap configuration for the supplied `nodeName`. - * - * NOTE: This lazily detects which wraps are necessary for the current browser. - * - * @param {string} nodeName Lowercase `nodeName`. - * @return {?array} Markup wrap configuration, if applicable. - */ - function getMarkupWrap(nodeName) { - !!!dummyNode ? process.env.NODE_ENV !== 'production' ? invariant(false, 'Markup wrapping node not initialized') : invariant(false) : undefined; - if (!markupWrap.hasOwnProperty(nodeName)) { - nodeName = '*'; - } - if (!shouldWrap.hasOwnProperty(nodeName)) { - if (nodeName === '*') { - dummyNode.innerHTML = ''; - } else { - dummyNode.innerHTML = '<' + nodeName + '>'; - } - shouldWrap[nodeName] = !dummyNode.firstChild; - } - return shouldWrap[nodeName] ? markupWrap[nodeName] : null; - } - - module.exports = getMarkupWrap; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 15 */ -/***/ function(module, exports) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule emptyFunction - */ - - "use strict"; - - function makeEmptyFunction(arg) { - return function () { - return arg; - }; - } - - /** - * This function accepts and discards inputs; it has no side effects. This is - * primarily useful idiomatically for overridable function endpoints which - * always need to be callable, since JS lacks a null-call idiom ala Cocoa. - */ - function emptyFunction() {} - - emptyFunction.thatReturns = makeEmptyFunction; - emptyFunction.thatReturnsFalse = makeEmptyFunction(false); - emptyFunction.thatReturnsTrue = makeEmptyFunction(true); - emptyFunction.thatReturnsNull = makeEmptyFunction(null); - emptyFunction.thatReturnsThis = function () { - return this; - }; - emptyFunction.thatReturnsArgument = function (arg) { - return arg; - }; - - module.exports = emptyFunction; - -/***/ }, -/* 16 */ -/***/ function(module, exports, __webpack_require__) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactMultiChildUpdateTypes - */ - - 'use strict'; - - var keyMirror = __webpack_require__(17); - - /** - * When a component's children are updated, a series of update configuration - * objects are created in order to batch and serialize the required changes. - * - * Enumerates all the possible types of update configurations. - * - * @internal - */ - var ReactMultiChildUpdateTypes = keyMirror({ - INSERT_MARKUP: null, - MOVE_EXISTING: null, - REMOVE_NODE: null, - SET_MARKUP: null, - TEXT_CONTENT: null - }); - - module.exports = ReactMultiChildUpdateTypes; - -/***/ }, -/* 17 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule keyMirror - * @typechecks static-only - */ - - 'use strict'; - - var invariant = __webpack_require__(13); - - /** - * Constructs an enumeration with keys equal to their value. - * - * For example: - * - * var COLORS = keyMirror({blue: null, red: null}); - * var myColor = COLORS.blue; - * var isColorValid = !!COLORS[myColor]; - * - * The last line could not be performed if the values of the generated enum were - * not equal to their keys. - * - * Input: {key1: val1, key2: val2} - * Output: {key1: key1, key2: key2} - * - * @param {object} obj - * @return {object} - */ - var keyMirror = function (obj) { - var ret = {}; - var key; - !(obj instanceof Object && !Array.isArray(obj)) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'keyMirror(...): Argument must be an object.') : invariant(false) : undefined; - for (key in obj) { - if (!obj.hasOwnProperty(key)) { - continue; - } - ret[key] = key; - } - return ret; - }; - - module.exports = keyMirror; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 18 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactPerf - * @typechecks static-only - */ - - 'use strict'; - - /** - * ReactPerf is a general AOP system designed to measure performance. This - * module only has the hooks: see ReactDefaultPerf for the analysis tool. - */ - var ReactPerf = { - /** - * Boolean to enable/disable measurement. Set to false by default to prevent - * accidental logging and perf loss. - */ - enableMeasure: false, - - /** - * Holds onto the measure function in use. By default, don't measure - * anything, but we'll override this if we inject a measure function. - */ - storedMeasure: _noMeasure, - - /** - * @param {object} object - * @param {string} objectName - * @param {object} methodNames - */ - measureMethods: function (object, objectName, methodNames) { - if (process.env.NODE_ENV !== 'production') { - for (var key in methodNames) { - if (!methodNames.hasOwnProperty(key)) { - continue; - } - object[key] = ReactPerf.measure(objectName, methodNames[key], object[key]); - } - } - }, - - /** - * Use this to wrap methods you want to measure. Zero overhead in production. - * - * @param {string} objName - * @param {string} fnName - * @param {function} func - * @return {function} - */ - measure: function (objName, fnName, func) { - if (process.env.NODE_ENV !== 'production') { - var measuredFunc = null; - var wrapper = function () { - if (ReactPerf.enableMeasure) { - if (!measuredFunc) { - measuredFunc = ReactPerf.storedMeasure(objName, fnName, func); - } - return measuredFunc.apply(this, arguments); - } - return func.apply(this, arguments); - }; - wrapper.displayName = objName + '_' + fnName; - return wrapper; - } - return func; - }, - - injection: { - /** - * @param {function} measure - */ - injectMeasure: function (measure) { - ReactPerf.storedMeasure = measure; - } - } - }; - - /** - * Simply passes through the measured function, without measuring it. - * - * @param {string} objName - * @param {string} fnName - * @param {function} func - * @return {function} - */ - function _noMeasure(objName, fnName, func) { - return func; - } - - module.exports = ReactPerf; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 19 */ -/***/ function(module, exports, __webpack_require__) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule setInnerHTML - */ - - /* globals MSApp */ - - 'use strict'; - - var ExecutionEnvironment = __webpack_require__(9); - - var WHITESPACE_TEST = /^[ \r\n\t\f]/; - var NONVISIBLE_TEST = /<(!--|link|noscript|meta|script|style)[ \r\n\t\f\/>]/; - - /** - * Set the innerHTML property of a node, ensuring that whitespace is preserved - * even in IE8. - * - * @param {DOMElement} node - * @param {string} html - * @internal - */ - var setInnerHTML = function (node, html) { - node.innerHTML = html; - }; - - // Win8 apps: Allow all html to be inserted - if (typeof MSApp !== 'undefined' && MSApp.execUnsafeLocalFunction) { - setInnerHTML = function (node, html) { - MSApp.execUnsafeLocalFunction(function () { - node.innerHTML = html; - }); - }; - } - - if (ExecutionEnvironment.canUseDOM) { - // IE8: When updating a just created node with innerHTML only leading - // whitespace is removed. When updating an existing node with innerHTML - // whitespace in root TextNodes is also collapsed. - // @see quirksmode.org/bugreports/archives/2004/11/innerhtml_and_t.html - - // Feature detection; only IE8 is known to behave improperly like this. - var testElement = document.createElement('div'); - testElement.innerHTML = ' '; - if (testElement.innerHTML === '') { - setInnerHTML = function (node, html) { - // Magic theory: IE8 supposedly differentiates between added and updated - // nodes when processing innerHTML, innerHTML on updated nodes suffers - // from worse whitespace behavior. Re-adding a node like this triggers - // the initial and more favorable whitespace behavior. - // TODO: What to do on a detached node? - if (node.parentNode) { - node.parentNode.replaceChild(node, node); - } - - // We also implement a workaround for non-visible tags disappearing into - // thin air on IE8, this only happens if there is no visible text - // in-front of the non-visible tags. Piggyback on the whitespace fix - // and simply check if any non-visible tags appear in the source. - if (WHITESPACE_TEST.test(html) || html[0] === '<' && NONVISIBLE_TEST.test(html)) { - // Recover leading whitespace by temporarily prepending any character. - // \uFEFF has the potential advantage of being zero-width/invisible. - // UglifyJS drops U+FEFF chars when parsing, so use String.fromCharCode - // in hopes that this is preserved even if "\uFEFF" is transformed to - // the actual Unicode character (by Babel, for example). - // https://github.com/mishoo/UglifyJS2/blob/v2.4.20/lib/parse.js#L216 - node.innerHTML = String.fromCharCode(0xFEFF) + html; - - // deleteData leaves an empty `TextNode` which offsets the index of all - // children. Definitely want to avoid this. - var textNode = node.firstChild; - if (textNode.data.length === 1) { - node.removeChild(textNode); - } else { - textNode.deleteData(0, 1); - } - } else { - node.innerHTML = html; - } - }; - } - } - - module.exports = setInnerHTML; - -/***/ }, -/* 20 */ -/***/ function(module, exports, __webpack_require__) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule setTextContent - */ - - 'use strict'; - - var ExecutionEnvironment = __webpack_require__(9); - var escapeTextContentForBrowser = __webpack_require__(21); - var setInnerHTML = __webpack_require__(19); - - /** - * Set the textContent property of a node, ensuring that whitespace is preserved - * even in IE8. innerText is a poor substitute for textContent and, among many - * issues, inserts
instead of the literal newline chars. innerHTML behaves - * as it should. - * - * @param {DOMElement} node - * @param {string} text - * @internal - */ - var setTextContent = function (node, text) { - node.textContent = text; - }; - - if (ExecutionEnvironment.canUseDOM) { - if (!('textContent' in document.documentElement)) { - setTextContent = function (node, text) { - setInnerHTML(node, escapeTextContentForBrowser(text)); - }; - } - } - - module.exports = setTextContent; - -/***/ }, -/* 21 */ -/***/ function(module, exports) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule escapeTextContentForBrowser - */ - - 'use strict'; - - var ESCAPE_LOOKUP = { - '&': '&', - '>': '>', - '<': '<', - '"': '"', - '\'': ''' - }; - - var ESCAPE_REGEX = /[&><"']/g; - - function escaper(match) { - return ESCAPE_LOOKUP[match]; - } - - /** - * Escapes text to prevent scripting attacks. - * - * @param {*} text Text value to escape. - * @return {string} An escaped string. - */ - function escapeTextContentForBrowser(text) { - return ('' + text).replace(ESCAPE_REGEX, escaper); - } - - module.exports = escapeTextContentForBrowser; - -/***/ }, -/* 22 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule DOMPropertyOperations - * @typechecks static-only - */ - - 'use strict'; - - var DOMProperty = __webpack_require__(23); - var ReactPerf = __webpack_require__(18); - - var quoteAttributeValueForBrowser = __webpack_require__(24); - var warning = __webpack_require__(25); - - // Simplified subset - var VALID_ATTRIBUTE_NAME_REGEX = /^[a-zA-Z_][\w\.\-]*$/; - var illegalAttributeNameCache = {}; - var validatedAttributeNameCache = {}; - - function isAttributeNameSafe(attributeName) { - if (validatedAttributeNameCache.hasOwnProperty(attributeName)) { - return true; - } - if (illegalAttributeNameCache.hasOwnProperty(attributeName)) { - return false; - } - if (VALID_ATTRIBUTE_NAME_REGEX.test(attributeName)) { - validatedAttributeNameCache[attributeName] = true; - return true; - } - illegalAttributeNameCache[attributeName] = true; - process.env.NODE_ENV !== 'production' ? warning(false, 'Invalid attribute name: `%s`', attributeName) : undefined; - return false; - } - - function shouldIgnoreValue(propertyInfo, value) { - return value == null || propertyInfo.hasBooleanValue && !value || propertyInfo.hasNumericValue && isNaN(value) || propertyInfo.hasPositiveNumericValue && value < 1 || propertyInfo.hasOverloadedBooleanValue && value === false; - } - - if (process.env.NODE_ENV !== 'production') { - var reactProps = { - children: true, - dangerouslySetInnerHTML: true, - key: true, - ref: true - }; - var warnedProperties = {}; - - var warnUnknownProperty = function (name) { - if (reactProps.hasOwnProperty(name) && reactProps[name] || warnedProperties.hasOwnProperty(name) && warnedProperties[name]) { - return; - } - - warnedProperties[name] = true; - var lowerCasedName = name.toLowerCase(); - - // data-* attributes should be lowercase; suggest the lowercase version - var standardName = DOMProperty.isCustomAttribute(lowerCasedName) ? lowerCasedName : DOMProperty.getPossibleStandardName.hasOwnProperty(lowerCasedName) ? DOMProperty.getPossibleStandardName[lowerCasedName] : null; - - // For now, only warn when we have a suggested correction. This prevents - // logging too much when using transferPropsTo. - process.env.NODE_ENV !== 'production' ? warning(standardName == null, 'Unknown DOM property %s. Did you mean %s?', name, standardName) : undefined; - }; - } - - /** - * Operations for dealing with DOM properties. - */ - var DOMPropertyOperations = { - - /** - * Creates markup for the ID property. - * - * @param {string} id Unescaped ID. - * @return {string} Markup string. - */ - createMarkupForID: function (id) { - return DOMProperty.ID_ATTRIBUTE_NAME + '=' + quoteAttributeValueForBrowser(id); - }, - - setAttributeForID: function (node, id) { - node.setAttribute(DOMProperty.ID_ATTRIBUTE_NAME, id); - }, - - /** - * Creates markup for a property. - * - * @param {string} name - * @param {*} value - * @return {?string} Markup string, or null if the property was invalid. - */ - createMarkupForProperty: function (name, value) { - var propertyInfo = DOMProperty.properties.hasOwnProperty(name) ? DOMProperty.properties[name] : null; - if (propertyInfo) { - if (shouldIgnoreValue(propertyInfo, value)) { - return ''; - } - var attributeName = propertyInfo.attributeName; - if (propertyInfo.hasBooleanValue || propertyInfo.hasOverloadedBooleanValue && value === true) { - return attributeName + '=""'; - } - return attributeName + '=' + quoteAttributeValueForBrowser(value); - } else if (DOMProperty.isCustomAttribute(name)) { - if (value == null) { - return ''; - } - return name + '=' + quoteAttributeValueForBrowser(value); - } else if (process.env.NODE_ENV !== 'production') { - warnUnknownProperty(name); - } - return null; - }, - - /** - * Creates markup for a custom property. - * - * @param {string} name - * @param {*} value - * @return {string} Markup string, or empty string if the property was invalid. - */ - createMarkupForCustomAttribute: function (name, value) { - if (!isAttributeNameSafe(name) || value == null) { - return ''; - } - return name + '=' + quoteAttributeValueForBrowser(value); - }, - - /** - * Sets the value for a property on a node. - * - * @param {DOMElement} node - * @param {string} name - * @param {*} value - */ - setValueForProperty: function (node, name, value) { - var propertyInfo = DOMProperty.properties.hasOwnProperty(name) ? DOMProperty.properties[name] : null; - if (propertyInfo) { - var mutationMethod = propertyInfo.mutationMethod; - if (mutationMethod) { - mutationMethod(node, value); - } else if (shouldIgnoreValue(propertyInfo, value)) { - this.deleteValueForProperty(node, name); - } else if (propertyInfo.mustUseAttribute) { - var attributeName = propertyInfo.attributeName; - var namespace = propertyInfo.attributeNamespace; - // `setAttribute` with objects becomes only `[object]` in IE8/9, - // ('' + value) makes it output the correct toString()-value. - if (namespace) { - node.setAttributeNS(namespace, attributeName, '' + value); - } else if (propertyInfo.hasBooleanValue || propertyInfo.hasOverloadedBooleanValue && value === true) { - node.setAttribute(attributeName, ''); - } else { - node.setAttribute(attributeName, '' + value); - } - } else { - var propName = propertyInfo.propertyName; - // Must explicitly cast values for HAS_SIDE_EFFECTS-properties to the - // property type before comparing; only `value` does and is string. - if (!propertyInfo.hasSideEffects || '' + node[propName] !== '' + value) { - // Contrary to `setAttribute`, object properties are properly - // `toString`ed by IE8/9. - node[propName] = value; - } - } - } else if (DOMProperty.isCustomAttribute(name)) { - DOMPropertyOperations.setValueForAttribute(node, name, value); - } else if (process.env.NODE_ENV !== 'production') { - warnUnknownProperty(name); - } - }, - - setValueForAttribute: function (node, name, value) { - if (!isAttributeNameSafe(name)) { - return; - } - if (value == null) { - node.removeAttribute(name); - } else { - node.setAttribute(name, '' + value); - } - }, - - /** - * Deletes the value for a property on a node. - * - * @param {DOMElement} node - * @param {string} name - */ - deleteValueForProperty: function (node, name) { - var propertyInfo = DOMProperty.properties.hasOwnProperty(name) ? DOMProperty.properties[name] : null; - if (propertyInfo) { - var mutationMethod = propertyInfo.mutationMethod; - if (mutationMethod) { - mutationMethod(node, undefined); - } else if (propertyInfo.mustUseAttribute) { - node.removeAttribute(propertyInfo.attributeName); - } else { - var propName = propertyInfo.propertyName; - var defaultValue = DOMProperty.getDefaultValueForProperty(node.nodeName, propName); - if (!propertyInfo.hasSideEffects || '' + node[propName] !== defaultValue) { - node[propName] = defaultValue; - } - } - } else if (DOMProperty.isCustomAttribute(name)) { - node.removeAttribute(name); - } else if (process.env.NODE_ENV !== 'production') { - warnUnknownProperty(name); - } - } - - }; - - ReactPerf.measureMethods(DOMPropertyOperations, 'DOMPropertyOperations', { - setValueForProperty: 'setValueForProperty', - setValueForAttribute: 'setValueForAttribute', - deleteValueForProperty: 'deleteValueForProperty' - }); - - module.exports = DOMPropertyOperations; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 23 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule DOMProperty - * @typechecks static-only - */ - - 'use strict'; - - var invariant = __webpack_require__(13); - - function checkMask(value, bitmask) { - return (value & bitmask) === bitmask; - } - - var DOMPropertyInjection = { - /** - * Mapping from normalized, camelcased property names to a configuration that - * specifies how the associated DOM property should be accessed or rendered. - */ - MUST_USE_ATTRIBUTE: 0x1, - MUST_USE_PROPERTY: 0x2, - HAS_SIDE_EFFECTS: 0x4, - HAS_BOOLEAN_VALUE: 0x8, - HAS_NUMERIC_VALUE: 0x10, - HAS_POSITIVE_NUMERIC_VALUE: 0x20 | 0x10, - HAS_OVERLOADED_BOOLEAN_VALUE: 0x40, - - /** - * Inject some specialized knowledge about the DOM. This takes a config object - * with the following properties: - * - * isCustomAttribute: function that given an attribute name will return true - * if it can be inserted into the DOM verbatim. Useful for data-* or aria-* - * attributes where it's impossible to enumerate all of the possible - * attribute names, - * - * Properties: object mapping DOM property name to one of the - * DOMPropertyInjection constants or null. If your attribute isn't in here, - * it won't get written to the DOM. - * - * DOMAttributeNames: object mapping React attribute name to the DOM - * attribute name. Attribute names not specified use the **lowercase** - * normalized name. - * - * DOMAttributeNamespaces: object mapping React attribute name to the DOM - * attribute namespace URL. (Attribute names not specified use no namespace.) - * - * DOMPropertyNames: similar to DOMAttributeNames but for DOM properties. - * Property names not specified use the normalized name. - * - * DOMMutationMethods: Properties that require special mutation methods. If - * `value` is undefined, the mutation method should unset the property. - * - * @param {object} domPropertyConfig the config as described above. - */ - injectDOMPropertyConfig: function (domPropertyConfig) { - var Injection = DOMPropertyInjection; - var Properties = domPropertyConfig.Properties || {}; - var DOMAttributeNamespaces = domPropertyConfig.DOMAttributeNamespaces || {}; - var DOMAttributeNames = domPropertyConfig.DOMAttributeNames || {}; - var DOMPropertyNames = domPropertyConfig.DOMPropertyNames || {}; - var DOMMutationMethods = domPropertyConfig.DOMMutationMethods || {}; - - if (domPropertyConfig.isCustomAttribute) { - DOMProperty._isCustomAttributeFunctions.push(domPropertyConfig.isCustomAttribute); - } - - for (var propName in Properties) { - !!DOMProperty.properties.hasOwnProperty(propName) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'injectDOMPropertyConfig(...): You\'re trying to inject DOM property ' + '\'%s\' which has already been injected. You may be accidentally ' + 'injecting the same DOM property config twice, or you may be ' + 'injecting two configs that have conflicting property names.', propName) : invariant(false) : undefined; - - var lowerCased = propName.toLowerCase(); - var propConfig = Properties[propName]; - - var propertyInfo = { - attributeName: lowerCased, - attributeNamespace: null, - propertyName: propName, - mutationMethod: null, - - mustUseAttribute: checkMask(propConfig, Injection.MUST_USE_ATTRIBUTE), - mustUseProperty: checkMask(propConfig, Injection.MUST_USE_PROPERTY), - hasSideEffects: checkMask(propConfig, Injection.HAS_SIDE_EFFECTS), - hasBooleanValue: checkMask(propConfig, Injection.HAS_BOOLEAN_VALUE), - hasNumericValue: checkMask(propConfig, Injection.HAS_NUMERIC_VALUE), - hasPositiveNumericValue: checkMask(propConfig, Injection.HAS_POSITIVE_NUMERIC_VALUE), - hasOverloadedBooleanValue: checkMask(propConfig, Injection.HAS_OVERLOADED_BOOLEAN_VALUE) - }; - - !(!propertyInfo.mustUseAttribute || !propertyInfo.mustUseProperty) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'DOMProperty: Cannot require using both attribute and property: %s', propName) : invariant(false) : undefined; - !(propertyInfo.mustUseProperty || !propertyInfo.hasSideEffects) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'DOMProperty: Properties that have side effects must use property: %s', propName) : invariant(false) : undefined; - !(propertyInfo.hasBooleanValue + propertyInfo.hasNumericValue + propertyInfo.hasOverloadedBooleanValue <= 1) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'DOMProperty: Value can be one of boolean, overloaded boolean, or ' + 'numeric value, but not a combination: %s', propName) : invariant(false) : undefined; - - if (process.env.NODE_ENV !== 'production') { - DOMProperty.getPossibleStandardName[lowerCased] = propName; - } - - if (DOMAttributeNames.hasOwnProperty(propName)) { - var attributeName = DOMAttributeNames[propName]; - propertyInfo.attributeName = attributeName; - if (process.env.NODE_ENV !== 'production') { - DOMProperty.getPossibleStandardName[attributeName] = propName; - } - } - - if (DOMAttributeNamespaces.hasOwnProperty(propName)) { - propertyInfo.attributeNamespace = DOMAttributeNamespaces[propName]; - } - - if (DOMPropertyNames.hasOwnProperty(propName)) { - propertyInfo.propertyName = DOMPropertyNames[propName]; - } - - if (DOMMutationMethods.hasOwnProperty(propName)) { - propertyInfo.mutationMethod = DOMMutationMethods[propName]; - } - - DOMProperty.properties[propName] = propertyInfo; - } - } - }; - var defaultValueCache = {}; - - /** - * DOMProperty exports lookup objects that can be used like functions: - * - * > DOMProperty.isValid['id'] - * true - * > DOMProperty.isValid['foobar'] - * undefined - * - * Although this may be confusing, it performs better in general. - * - * @see http://jsperf.com/key-exists - * @see http://jsperf.com/key-missing - */ - var DOMProperty = { - - ID_ATTRIBUTE_NAME: 'data-reactid', - - /** - * Map from property "standard name" to an object with info about how to set - * the property in the DOM. Each object contains: - * - * attributeName: - * Used when rendering markup or with `*Attribute()`. - * attributeNamespace - * propertyName: - * Used on DOM node instances. (This includes properties that mutate due to - * external factors.) - * mutationMethod: - * If non-null, used instead of the property or `setAttribute()` after - * initial render. - * mustUseAttribute: - * Whether the property must be accessed and mutated using `*Attribute()`. - * (This includes anything that fails ` in `.) - * mustUseProperty: - * Whether the property must be accessed and mutated as an object property. - * hasSideEffects: - * Whether or not setting a value causes side effects such as triggering - * resources to be loaded or text selection changes. If true, we read from - * the DOM before updating to ensure that the value is only set if it has - * changed. - * hasBooleanValue: - * Whether the property should be removed when set to a falsey value. - * hasNumericValue: - * Whether the property must be numeric or parse as a numeric and should be - * removed when set to a falsey value. - * hasPositiveNumericValue: - * Whether the property must be positive numeric or parse as a positive - * numeric and should be removed when set to a falsey value. - * hasOverloadedBooleanValue: - * Whether the property can be used as a flag as well as with a value. - * Removed when strictly equal to false; present without a value when - * strictly equal to true; present with a value otherwise. - */ - properties: {}, - - /** - * Mapping from lowercase property names to the properly cased version, used - * to warn in the case of missing properties. Available only in __DEV__. - * @type {Object} - */ - getPossibleStandardName: process.env.NODE_ENV !== 'production' ? {} : null, - - /** - * All of the isCustomAttribute() functions that have been injected. - */ - _isCustomAttributeFunctions: [], - - /** - * Checks whether a property name is a custom attribute. - * @method - */ - isCustomAttribute: function (attributeName) { - for (var i = 0; i < DOMProperty._isCustomAttributeFunctions.length; i++) { - var isCustomAttributeFn = DOMProperty._isCustomAttributeFunctions[i]; - if (isCustomAttributeFn(attributeName)) { - return true; - } - } - return false; - }, - - /** - * Returns the default property value for a DOM property (i.e., not an - * attribute). Most default values are '' or false, but not all. Worse yet, - * some (in particular, `type`) vary depending on the type of element. - * - * TODO: Is it better to grab all the possible properties when creating an - * element to avoid having to create the same element twice? - */ - getDefaultValueForProperty: function (nodeName, prop) { - var nodeDefaults = defaultValueCache[nodeName]; - var testElement; - if (!nodeDefaults) { - defaultValueCache[nodeName] = nodeDefaults = {}; - } - if (!(prop in nodeDefaults)) { - testElement = document.createElement(nodeName); - nodeDefaults[prop] = testElement[prop]; - } - return nodeDefaults[prop]; - }, - - injection: DOMPropertyInjection - }; - - module.exports = DOMProperty; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 24 */ -/***/ function(module, exports, __webpack_require__) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule quoteAttributeValueForBrowser - */ - - 'use strict'; - - var escapeTextContentForBrowser = __webpack_require__(21); - - /** - * Escapes attribute value to prevent scripting attacks. - * - * @param {*} value Value to escape. - * @return {string} An escaped string. - */ - function quoteAttributeValueForBrowser(value) { - return '"' + escapeTextContentForBrowser(value) + '"'; - } - - module.exports = quoteAttributeValueForBrowser; - -/***/ }, -/* 25 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2014-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule warning - */ - - 'use strict'; - - var emptyFunction = __webpack_require__(15); - - /** - * Similar to invariant but only logs a warning if the condition is not met. - * This can be used to log issues in development environments in critical - * paths. Removing the logging code for production environments will keep the - * same logic and follow the same code paths. - */ - - var warning = emptyFunction; - - if (process.env.NODE_ENV !== 'production') { - warning = function (condition, format) { - for (var _len = arguments.length, args = Array(_len > 2 ? _len - 2 : 0), _key = 2; _key < _len; _key++) { - args[_key - 2] = arguments[_key]; - } - - if (format === undefined) { - throw new Error('`warning(condition, format, ...args)` requires a warning ' + 'message argument'); - } - - if (format.indexOf('Failed Composite propType: ') === 0) { - return; // Ignore CompositeComponent proptype check. - } - - if (!condition) { - var argIndex = 0; - var message = 'Warning: ' + format.replace(/%s/g, function () { - return args[argIndex++]; - }); - if (typeof console !== 'undefined') { - console.error(message); - } - try { - // --- Welcome to debugging React --- - // This error was thrown as a convenience so that you can use this stack - // to find the callsite that caused this warning to fire. - throw new Error(message); - } catch (x) {} - } - }; - } - - module.exports = warning; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 26 */ -/***/ function(module, exports, __webpack_require__) { - - /** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactComponentBrowserEnvironment - */ - - 'use strict'; - - var ReactDOMIDOperations = __webpack_require__(27); - var ReactMount = __webpack_require__(28); - - /** - * Abstracts away all functionality of the reconciler that requires knowledge of - * the browser context. TODO: These callers should be refactored to avoid the - * need for this injection. - */ - var ReactComponentBrowserEnvironment = { - - processChildrenUpdates: ReactDOMIDOperations.dangerouslyProcessChildrenUpdates, - - replaceNodeWithMarkupByID: ReactDOMIDOperations.dangerouslyReplaceNodeWithMarkupByID, - - /** - * If a particular environment requires that some resources be cleaned up, - * specify this in the injected Mixin. In the DOM, we would likely want to - * purge any cached node ID lookups. - * - * @private - */ - unmountIDFromEnvironment: function (rootNodeID) { - ReactMount.purgeID(rootNodeID); - } - - }; - - module.exports = ReactComponentBrowserEnvironment; - -/***/ }, -/* 27 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactDOMIDOperations - * @typechecks static-only - */ - - 'use strict'; - - var DOMChildrenOperations = __webpack_require__(7); - var DOMPropertyOperations = __webpack_require__(22); - var ReactMount = __webpack_require__(28); - var ReactPerf = __webpack_require__(18); - - var invariant = __webpack_require__(13); - - /** - * Errors for properties that should not be updated with `updatePropertyByID()`. - * - * @type {object} - * @private - */ - var INVALID_PROPERTY_ERRORS = { - dangerouslySetInnerHTML: '`dangerouslySetInnerHTML` must be set using `updateInnerHTMLByID()`.', - style: '`style` must be set using `updateStylesByID()`.' - }; - - /** - * Operations used to process updates to DOM nodes. - */ - var ReactDOMIDOperations = { - - /** - * Updates a DOM node with new property values. This should only be used to - * update DOM properties in `DOMProperty`. - * - * @param {string} id ID of the node to update. - * @param {string} name A valid property name, see `DOMProperty`. - * @param {*} value New value of the property. - * @internal - */ - updatePropertyByID: function (id, name, value) { - var node = ReactMount.getNode(id); - !!INVALID_PROPERTY_ERRORS.hasOwnProperty(name) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'updatePropertyByID(...): %s', INVALID_PROPERTY_ERRORS[name]) : invariant(false) : undefined; - - // If we're updating to null or undefined, we should remove the property - // from the DOM node instead of inadvertantly setting to a string. This - // brings us in line with the same behavior we have on initial render. - if (value != null) { - DOMPropertyOperations.setValueForProperty(node, name, value); - } else { - DOMPropertyOperations.deleteValueForProperty(node, name); - } - }, - - /** - * Replaces a DOM node that exists in the document with markup. - * - * @param {string} id ID of child to be replaced. - * @param {string} markup Dangerous markup to inject in place of child. - * @internal - * @see {Danger.dangerouslyReplaceNodeWithMarkup} - */ - dangerouslyReplaceNodeWithMarkupByID: function (id, markup) { - var node = ReactMount.getNode(id); - DOMChildrenOperations.dangerouslyReplaceNodeWithMarkup(node, markup); - }, - - /** - * Updates a component's children by processing a series of updates. - * - * @param {array} updates List of update configurations. - * @param {array} markup List of markup strings. - * @internal - */ - dangerouslyProcessChildrenUpdates: function (updates, markup) { - for (var i = 0; i < updates.length; i++) { - updates[i].parentNode = ReactMount.getNode(updates[i].parentID); - } - DOMChildrenOperations.processUpdates(updates, markup); - } - }; - - ReactPerf.measureMethods(ReactDOMIDOperations, 'ReactDOMIDOperations', { - dangerouslyReplaceNodeWithMarkupByID: 'dangerouslyReplaceNodeWithMarkupByID', - dangerouslyProcessChildrenUpdates: 'dangerouslyProcessChildrenUpdates' - }); - - module.exports = ReactDOMIDOperations; - /* WEBPACK VAR INJECTION */}.call(exports, __webpack_require__(4))) - -/***/ }, -/* 28 */ -/***/ function(module, exports, __webpack_require__) { - - /* WEBPACK VAR INJECTION */(function(process) {/** - * Copyright 2013-2015, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - * @providesModule ReactMount - */ - - 'use strict'; - - var DOMProperty = __webpack_require__(23); - var ReactBrowserEventEmitter = __webpack_require__(29); - var ReactCurrentOwner = __webpack_require__(5); - var ReactDOMFeatureFlags = __webpack_require__(41); - var ReactElement = __webpack_require__(42); - var ReactEmptyComponentRegistry = __webpack_require__(44); - var ReactInstanceHandles = __webpack_require__(45); - var ReactInstanceMap = __webpack_require__(47); - var ReactMarkupChecksum = __webpack_require__(48); - var ReactPerf = __webpack_require__(18); - var ReactReconciler = __webpack_require__(50); - var ReactUpdateQueue = __webpack_require__(53); - var ReactUpdates = __webpack_require__(54); - - var assign = __webpack_require__(39); - var emptyObject = __webpack_require__(58); - var containsNode = __webpack_require__(59); - var instantiateReactComponent = __webpack_require__(62); - var invariant = __webpack_require__(13); - var setInnerHTML = __webpack_require__(19); - var shouldUpdateReactComponent = __webpack_require__(67); - var validateDOMNesting = __webpack_require__(70); - var warning = __webpack_require__(25); - - var ATTR_NAME = DOMProperty.ID_ATTRIBUTE_NAME; - var nodeCache = {}; - - var ELEMENT_NODE_TYPE = 1; - var DOC_NODE_TYPE = 9; - var DOCUMENT_FRAGMENT_NODE_TYPE = 11; - - var ownerDocumentContextKey = '__ReactMount_ownerDocument$' + Math.random().toString(36).slice(2); - - /** Mapping from reactRootID to React component instance. */ - var instancesByReactRootID = {}; - - /** Mapping from reactRootID to `container` nodes. */ - var containersByReactRootID = {}; - - if (process.env.NODE_ENV !== 'production') { - /** __DEV__-only mapping from reactRootID to root elements. */ - var rootElementsByReactRootID = {}; - } - - // Used to store breadth-first search state in findComponentRoot. - var findComponentRootReusableArray = []; - - /** - * Finds the index of the first character - * that's not common between the two given strings. - * - * @return {number} the index of the character where the strings diverge - */ - function firstDifferenceIndex(string1, string2) { - var minLen = Math.min(string1.length, string2.length); - for (var i = 0; i < minLen; i++) { - if (string1.charAt(i) !== string2.charAt(i)) { - return i; - } - } - return string1.length === string2.length ? -1 : minLen; - } - - /** - * @param {DOMElement|DOMDocument} container DOM element that may contain - * a React component - * @return {?*} DOM element that may have the reactRoot ID, or null. - */ - function getReactRootElementInContainer(container) { - if (!container) { - return null; - } - - if (container.nodeType === DOC_NODE_TYPE) { - return container.documentElement; - } else { - return container.firstChild; - } - } - - /** - * @param {DOMElement} container DOM element that may contain a React component. - * @return {?string} A "reactRoot" ID, if a React component is rendered. - */ - function getReactRootID(container) { - var rootElement = getReactRootElementInContainer(container); - return rootElement && ReactMount.getID(rootElement); - } - - /** - * Accessing node[ATTR_NAME] or calling getAttribute(ATTR_NAME) on a form - * element can return its control whose name or ID equals ATTR_NAME. All - * DOM nodes support `getAttributeNode` but this can also get called on - * other objects so just return '' if we're given something other than a - * DOM node (such as window). - * - * @param {?DOMElement|DOMWindow|DOMDocument|DOMTextNode} node DOM node. - * @return {string} ID of the supplied `domNode`. - */ - function getID(node) { - var id = internalGetID(node); - if (id) { - if (nodeCache.hasOwnProperty(id)) { - var cached = nodeCache[id]; - if (cached !== node) { - !!isValid(cached, id) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'ReactMount: Two valid but unequal nodes with the same `%s`: %s', ATTR_NAME, id) : invariant(false) : undefined; - - nodeCache[id] = node; - } - } else { - nodeCache[id] = node; - } - } - - return id; - } - - function internalGetID(node) { - // If node is something like a window, document, or text node, none of - // which support attributes or a .getAttribute method, gracefully return - // the empty string, as if the attribute were missing. - return node && node.getAttribute && node.getAttribute(ATTR_NAME) || ''; - } - - /** - * Sets the React-specific ID of the given node. - * - * @param {DOMElement} node The DOM node whose ID will be set. - * @param {string} id The value of the ID attribute. - */ - function setID(node, id) { - var oldID = internalGetID(node); - if (oldID !== id) { - delete nodeCache[oldID]; - } - node.setAttribute(ATTR_NAME, id); - nodeCache[id] = node; - } - - /** - * Finds the node with the supplied React-generated DOM ID. - * - * @param {string} id A React-generated DOM ID. - * @return {DOMElement} DOM node with the suppled `id`. - * @internal - */ - function getNode(id) { - if (!nodeCache.hasOwnProperty(id) || !isValid(nodeCache[id], id)) { - nodeCache[id] = ReactMount.findReactNodeByID(id); - } - return nodeCache[id]; - } - - /** - * Finds the node with the supplied public React instance. - * - * @param {*} instance A public React instance. - * @return {?DOMElement} DOM node with the suppled `id`. - * @internal - */ - function getNodeFromInstance(instance) { - var id = ReactInstanceMap.get(instance)._rootNodeID; - if (ReactEmptyComponentRegistry.isNullComponentID(id)) { - return null; - } - if (!nodeCache.hasOwnProperty(id) || !isValid(nodeCache[id], id)) { - nodeCache[id] = ReactMount.findReactNodeByID(id); - } - return nodeCache[id]; - } - - /** - * A node is "valid" if it is contained by a currently mounted container. - * - * This means that the node does not have to be contained by a document in - * order to be considered valid. - * - * @param {?DOMElement} node The candidate DOM node. - * @param {string} id The expected ID of the node. - * @return {boolean} Whether the node is contained by a mounted container. - */ - function isValid(node, id) { - if (node) { - !(internalGetID(node) === id) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'ReactMount: Unexpected modification of `%s`', ATTR_NAME) : invariant(false) : undefined; - - var container = ReactMount.findReactContainerForID(id); - if (container && containsNode(container, node)) { - return true; - } - } - - return false; - } - - /** - * Causes the cache to forget about one React-specific ID. - * - * @param {string} id The ID to forget. - */ - function purgeID(id) { - delete nodeCache[id]; - } - - var deepestNodeSoFar = null; - function findDeepestCachedAncestorImpl(ancestorID) { - var ancestor = nodeCache[ancestorID]; - if (ancestor && isValid(ancestor, ancestorID)) { - deepestNodeSoFar = ancestor; - } else { - // This node isn't populated in the cache, so presumably none of its - // descendants are. Break out of the loop. - return false; - } - } - - /** - * Return the deepest cached node whose ID is a prefix of `targetID`. - */ - function findDeepestCachedAncestor(targetID) { - deepestNodeSoFar = null; - ReactInstanceHandles.traverseAncestors(targetID, findDeepestCachedAncestorImpl); - - var foundNode = deepestNodeSoFar; - deepestNodeSoFar = null; - return foundNode; - } - - /** - * Mounts this component and inserts it into the DOM. - * - * @param {ReactComponent} componentInstance The instance to mount. - * @param {string} rootID DOM ID of the root node. - * @param {DOMElement} container DOM element to mount into. - * @param {ReactReconcileTransaction} transaction - * @param {boolean} shouldReuseMarkup If true, do not insert markup - */ - function mountComponentIntoNode(componentInstance, rootID, container, transaction, shouldReuseMarkup, context) { - if (ReactDOMFeatureFlags.useCreateElement) { - context = assign({}, context); - if (container.nodeType === DOC_NODE_TYPE) { - context[ownerDocumentContextKey] = container; - } else { - context[ownerDocumentContextKey] = container.ownerDocument; - } - } - if (process.env.NODE_ENV !== 'production') { - if (context === emptyObject) { - context = {}; - } - var tag = container.nodeName.toLowerCase(); - context[validateDOMNesting.ancestorInfoContextKey] = validateDOMNesting.updatedAncestorInfo(null, tag, null); - } - var markup = ReactReconciler.mountComponent(componentInstance, rootID, transaction, context); - componentInstance._renderedComponent._topLevelWrapper = componentInstance; - ReactMount._mountImageIntoNode(markup, container, shouldReuseMarkup, transaction); - } - - /** - * Batched mount. - * - * @param {ReactComponent} componentInstance The instance to mount. - * @param {string} rootID DOM ID of the root node. - * @param {DOMElement} container DOM element to mount into. - * @param {boolean} shouldReuseMarkup If true, do not insert markup - */ - function batchedMountComponentIntoNode(componentInstance, rootID, container, shouldReuseMarkup, context) { - var transaction = ReactUpdates.ReactReconcileTransaction.getPooled( - /* forceHTML */shouldReuseMarkup); - transaction.perform(mountComponentIntoNode, null, componentInstance, rootID, container, transaction, shouldReuseMarkup, context); - ReactUpdates.ReactReconcileTransaction.release(transaction); - } - - /** - * Unmounts a component and removes it from the DOM. - * - * @param {ReactComponent} instance React component instance. - * @param {DOMElement} container DOM element to unmount from. - * @final - * @internal - * @see {ReactMount.unmountComponentAtNode} - */ - function unmountComponentFromNode(instance, container) { - ReactReconciler.unmountComponent(instance); - - if (container.nodeType === DOC_NODE_TYPE) { - container = container.documentElement; - } - - // http://jsperf.com/emptying-a-node - while (container.lastChild) { - container.removeChild(container.lastChild); - } - } - - /** - * True if the supplied DOM node has a direct React-rendered child that is - * not a React root element. Useful for warning in `render`, - * `unmountComponentAtNode`, etc. - * - * @param {?DOMElement} node The candidate DOM node. - * @return {boolean} True if the DOM element contains a direct child that was - * rendered by React but is not a root element. - * @internal - */ - function hasNonRootReactChild(node) { - var reactRootID = getReactRootID(node); - return reactRootID ? reactRootID !== ReactInstanceHandles.getReactRootIDFromNodeID(reactRootID) : false; - } - - /** - * Returns the first (deepest) ancestor of a node which is rendered by this copy - * of React. - */ - function findFirstReactDOMImpl(node) { - // This node might be from another React instance, so we make sure not to - // examine the node cache here - for (; node && node.parentNode !== node; node = node.parentNode) { - if (node.nodeType !== 1) { - // Not a DOMElement, therefore not a React component - continue; - } - var nodeID = internalGetID(node); - if (!nodeID) { - continue; - } - var reactRootID = ReactInstanceHandles.getReactRootIDFromNodeID(nodeID); - - // If containersByReactRootID contains the container we find by crawling up - // the tree, we know that this instance of React rendered the node. - // nb. isValid's strategy (with containsNode) does not work because render - // trees may be nested and we don't want a false positive in that case. - var current = node; - var lastID; - do { - lastID = internalGetID(current); - current = current.parentNode; - if (current == null) { - // The passed-in node has been detached from the container it was - // originally rendered into. - return null; - } - } while (lastID !== reactRootID); - - if (current === containersByReactRootID[reactRootID]) { - return node; - } - } - return null; - } - - /** - * Temporary (?) hack so that we can store all top-level pending updates on - * composites instead of having to worry about different types of components - * here. - */ - var TopLevelWrapper = function () {}; - TopLevelWrapper.prototype.isReactComponent = {}; - if (process.env.NODE_ENV !== 'production') { - TopLevelWrapper.displayName = 'TopLevelWrapper'; - } - TopLevelWrapper.prototype.render = function () { - // this.props is actually a ReactElement - return this.props; - }; - - /** - * Mounting is the process of initializing a React component by creating its - * representative DOM elements and inserting them into a supplied `container`. - * Any prior content inside `container` is destroyed in the process. - * - * ReactMount.render( - * component, - * document.getElementById('container') - * ); - * - *
<-- Supplied `container`. - *
<-- Rendered reactRoot of React - * // ... component. - *
- *
- * - * Inside of `container`, the first element rendered is the "reactRoot". - */ - var ReactMount = { - - TopLevelWrapper: TopLevelWrapper, - - /** Exposed for debugging purposes **/ - _instancesByReactRootID: instancesByReactRootID, - - /** - * This is a hook provided to support rendering React components while - * ensuring that the apparent scroll position of its `container` does not - * change. - * - * @param {DOMElement} container The `container` being rendered into. - * @param {function} renderCallback This must be called once to do the render. - */ - scrollMonitor: function (container, renderCallback) { - renderCallback(); - }, - - /** - * Take a component that's already mounted into the DOM and replace its props - * @param {ReactComponent} prevComponent component instance already in the DOM - * @param {ReactElement} nextElement component instance to render - * @param {DOMElement} container container to render into - * @param {?function} callback function triggered on completion - */ - _updateRootComponent: function (prevComponent, nextElement, container, callback) { - ReactMount.scrollMonitor(container, function () { - ReactUpdateQueue.enqueueElementInternal(prevComponent, nextElement); - if (callback) { - ReactUpdateQueue.enqueueCallbackInternal(prevComponent, callback); - } - }); - - if (process.env.NODE_ENV !== 'production') { - // Record the root element in case it later gets transplanted. - rootElementsByReactRootID[getReactRootID(container)] = getReactRootElementInContainer(container); - } - - return prevComponent; - }, - - /** - * Register a component into the instance map and starts scroll value - * monitoring - * @param {ReactComponent} nextComponent component instance to render - * @param {DOMElement} container container to render into - * @return {string} reactRoot ID prefix - */ - _registerComponent: function (nextComponent, container) { - !(container && (container.nodeType === ELEMENT_NODE_TYPE || container.nodeType === DOC_NODE_TYPE || container.nodeType === DOCUMENT_FRAGMENT_NODE_TYPE)) ? process.env.NODE_ENV !== 'production' ? invariant(false, '_registerComponent(...): Target container is not a DOM element.') : invariant(false) : undefined; - - ReactBrowserEventEmitter.ensureScrollValueMonitoring(); - - var reactRootID = ReactMount.registerContainer(container); - instancesByReactRootID[reactRootID] = nextComponent; - return reactRootID; - }, - - /** - * Render a new component into the DOM. - * @param {ReactElement} nextElement element to render - * @param {DOMElement} container container to render into - * @param {boolean} shouldReuseMarkup if we should skip the markup insertion - * @return {ReactComponent} nextComponent - */ - _renderNewRootComponent: function (nextElement, container, shouldReuseMarkup, context) { - // Various parts of our code (such as ReactCompositeComponent's - // _renderValidatedComponent) assume that calls to render aren't nested; - // verify that that's the case. - process.env.NODE_ENV !== 'production' ? warning(ReactCurrentOwner.current == null, '_renderNewRootComponent(): Render methods should be a pure function ' + 'of props and state; triggering nested component updates from ' + 'render is not allowed. If necessary, trigger nested updates in ' + 'componentDidUpdate. Check the render method of %s.', ReactCurrentOwner.current && ReactCurrentOwner.current.getName() || 'ReactCompositeComponent') : undefined; - - var componentInstance = instantiateReactComponent(nextElement, null); - var reactRootID = ReactMount._registerComponent(componentInstance, container); - - // The initial render is synchronous but any updates that happen during - // rendering, in componentWillMount or componentDidMount, will be batched - // according to the current batching strategy. - - ReactUpdates.batchedUpdates(batchedMountComponentIntoNode, componentInstance, reactRootID, container, shouldReuseMarkup, context); - - if (process.env.NODE_ENV !== 'production') { - // Record the root element in case it later gets transplanted. - rootElementsByReactRootID[reactRootID] = getReactRootElementInContainer(container); - } - - return componentInstance; - }, - - /** - * Renders a React component into the DOM in the supplied `container`. - * - * If the React component was previously rendered into `container`, this will - * perform an update on it and only mutate the DOM as necessary to reflect the - * latest React component. - * - * @param {ReactComponent} parentComponent The conceptual parent of this render tree. - * @param {ReactElement} nextElement Component element to render. - * @param {DOMElement} container DOM element to render into. - * @param {?function} callback function triggered on completion - * @return {ReactComponent} Component instance rendered in `container`. - */ - renderSubtreeIntoContainer: function (parentComponent, nextElement, container, callback) { - !(parentComponent != null && parentComponent._reactInternalInstance != null) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'parentComponent must be a valid React Component') : invariant(false) : undefined; - return ReactMount._renderSubtreeIntoContainer(parentComponent, nextElement, container, callback); - }, - - _renderSubtreeIntoContainer: function (parentComponent, nextElement, container, callback) { - !ReactElement.isValidElement(nextElement) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'ReactDOM.render(): Invalid component element.%s', typeof nextElement === 'string' ? ' Instead of passing an element string, make sure to instantiate ' + 'it by passing it to React.createElement.' : typeof nextElement === 'function' ? ' Instead of passing a component class, make sure to instantiate ' + 'it by passing it to React.createElement.' : - // Check if it quacks like an element - nextElement != null && nextElement.props !== undefined ? ' This may be caused by unintentionally loading two independent ' + 'copies of React.' : '') : invariant(false) : undefined; - - process.env.NODE_ENV !== 'production' ? warning(!container || !container.tagName || container.tagName.toUpperCase() !== 'BODY', 'render(): Rendering components directly into document.body is ' + 'discouraged, since its children are often manipulated by third-party ' + 'scripts and browser extensions. This may lead to subtle ' + 'reconciliation issues. Try rendering into a container element created ' + 'for your app.') : undefined; - - var nextWrappedElement = new ReactElement(TopLevelWrapper, null, null, null, null, null, nextElement); - - var prevComponent = instancesByReactRootID[getReactRootID(container)]; - - if (prevComponent) { - var prevWrappedElement = prevComponent._currentElement; - var prevElement = prevWrappedElement.props; - if (shouldUpdateReactComponent(prevElement, nextElement)) { - var publicInst = prevComponent._renderedComponent.getPublicInstance(); - var updatedCallback = callback && function () { - callback.call(publicInst); - }; - ReactMount._updateRootComponent(prevComponent, nextWrappedElement, container, updatedCallback); - return publicInst; - } else { - ReactMount.unmountComponentAtNode(container); - } - } - - var reactRootElement = getReactRootElementInContainer(container); - var containerHasReactMarkup = reactRootElement && !!internalGetID(reactRootElement); - var containerHasNonRootReactChild = hasNonRootReactChild(container); - - if (process.env.NODE_ENV !== 'production') { - process.env.NODE_ENV !== 'production' ? warning(!containerHasNonRootReactChild, 'render(...): Replacing React-rendered children with a new root ' + 'component. If you intended to update the children of this node, ' + 'you should instead have the existing children update their state ' + 'and render the new components instead of calling ReactDOM.render.') : undefined; - - if (!containerHasReactMarkup || reactRootElement.nextSibling) { - var rootElementSibling = reactRootElement; - while (rootElementSibling) { - if (internalGetID(rootElementSibling)) { - process.env.NODE_ENV !== 'production' ? warning(false, 'render(): Target node has markup rendered by React, but there ' + 'are unrelated nodes as well. This is most commonly caused by ' + 'white-space inserted around server-rendered markup.') : undefined; - break; - } - rootElementSibling = rootElementSibling.nextSibling; - } - } - } - - var shouldReuseMarkup = containerHasReactMarkup && !prevComponent && !containerHasNonRootReactChild; - var component = ReactMount._renderNewRootComponent(nextWrappedElement, container, shouldReuseMarkup, parentComponent != null ? parentComponent._reactInternalInstance._processChildContext(parentComponent._reactInternalInstance._context) : emptyObject)._renderedComponent.getPublicInstance(); - if (callback) { - callback.call(component); - } - return component; - }, - - /** - * Renders a React component into the DOM in the supplied `container`. - * - * If the React component was previously rendered into `container`, this will - * perform an update on it and only mutate the DOM as necessary to reflect the - * latest React component. - * - * @param {ReactElement} nextElement Component element to render. - * @param {DOMElement} container DOM element to render into. - * @param {?function} callback function triggered on completion - * @return {ReactComponent} Component instance rendered in `container`. - */ - render: function (nextElement, container, callback) { - return ReactMount._renderSubtreeIntoContainer(null, nextElement, container, callback); - }, - - /** - * Registers a container node into which React components will be rendered. - * This also creates the "reactRoot" ID that will be assigned to the element - * rendered within. - * - * @param {DOMElement} container DOM element to register as a container. - * @return {string} The "reactRoot" ID of elements rendered within. - */ - registerContainer: function (container) { - var reactRootID = getReactRootID(container); - if (reactRootID) { - // If one exists, make sure it is a valid "reactRoot" ID. - reactRootID = ReactInstanceHandles.getReactRootIDFromNodeID(reactRootID); - } - if (!reactRootID) { - // No valid "reactRoot" ID found, create one. - reactRootID = ReactInstanceHandles.createReactRootID(); - } - containersByReactRootID[reactRootID] = container; - return reactRootID; - }, - - /** - * Unmounts and destroys the React component rendered in the `container`. - * - * @param {DOMElement} container DOM element containing a React component. - * @return {boolean} True if a component was found in and unmounted from - * `container` - */ - unmountComponentAtNode: function (container) { - // Various parts of our code (such as ReactCompositeComponent's - // _renderValidatedComponent) assume that calls to render aren't nested; - // verify that that's the case. (Strictly speaking, unmounting won't cause a - // render but we still don't expect to be in a render call here.) - process.env.NODE_ENV !== 'production' ? warning(ReactCurrentOwner.current == null, 'unmountComponentAtNode(): Render methods should be a pure function ' + 'of props and state; triggering nested component updates from render ' + 'is not allowed. If necessary, trigger nested updates in ' + 'componentDidUpdate. Check the render method of %s.', ReactCurrentOwner.current && ReactCurrentOwner.current.getName() || 'ReactCompositeComponent') : undefined; - - !(container && (container.nodeType === ELEMENT_NODE_TYPE || container.nodeType === DOC_NODE_TYPE || container.nodeType === DOCUMENT_FRAGMENT_NODE_TYPE)) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'unmountComponentAtNode(...): Target container is not a DOM element.') : invariant(false) : undefined; - - var reactRootID = getReactRootID(container); - var component = instancesByReactRootID[reactRootID]; - if (!component) { - // Check if the node being unmounted was rendered by React, but isn't a - // root node. - var containerHasNonRootReactChild = hasNonRootReactChild(container); - - // Check if the container itself is a React root node. - var containerID = internalGetID(container); - var isContainerReactRoot = containerID && containerID === ReactInstanceHandles.getReactRootIDFromNodeID(containerID); - - if (process.env.NODE_ENV !== 'production') { - process.env.NODE_ENV !== 'production' ? warning(!containerHasNonRootReactChild, 'unmountComponentAtNode(): The node you\'re attempting to unmount ' + 'was rendered by React and is not a top-level container. %s', isContainerReactRoot ? 'You may have accidentally passed in a React root node instead ' + 'of its container.' : 'Instead, have the parent component update its state and ' + 'rerender in order to remove this component.') : undefined; - } - - return false; - } - ReactUpdates.batchedUpdates(unmountComponentFromNode, component, container); - delete instancesByReactRootID[reactRootID]; - delete containersByReactRootID[reactRootID]; - if (process.env.NODE_ENV !== 'production') { - delete rootElementsByReactRootID[reactRootID]; - } - return true; - }, - - /** - * Finds the container DOM element that contains React component to which the - * supplied DOM `id` belongs. - * - * @param {string} id The ID of an element rendered by a React component. - * @return {?DOMElement} DOM element that contains the `id`. - */ - findReactContainerForID: function (id) { - var reactRootID = ReactInstanceHandles.getReactRootIDFromNodeID(id); - var container = containersByReactRootID[reactRootID]; - - if (process.env.NODE_ENV !== 'production') { - var rootElement = rootElementsByReactRootID[reactRootID]; - if (rootElement && rootElement.parentNode !== container) { - process.env.NODE_ENV !== 'production' ? warning( - // Call internalGetID here because getID calls isValid which calls - // findReactContainerForID (this function). - internalGetID(rootElement) === reactRootID, 'ReactMount: Root element ID differed from reactRootID.') : undefined; - var containerChild = container.firstChild; - if (containerChild && reactRootID === internalGetID(containerChild)) { - // If the container has a new child with the same ID as the old - // root element, then rootElementsByReactRootID[reactRootID] is - // just stale and needs to be updated. The case that deserves a - // warning is when the container is empty. - rootElementsByReactRootID[reactRootID] = containerChild; - } else { - process.env.NODE_ENV !== 'production' ? warning(false, 'ReactMount: Root element has been removed from its original ' + 'container. New container: %s', rootElement.parentNode) : undefined; - } - } - } - - return container; - }, - - /** - * Finds an element rendered by React with the supplied ID. - * - * @param {string} id ID of a DOM node in the React component. - * @return {DOMElement} Root DOM node of the React component. - */ - findReactNodeByID: function (id) { - var reactRoot = ReactMount.findReactContainerForID(id); - return ReactMount.findComponentRoot(reactRoot, id); - }, - - /** - * Traverses up the ancestors of the supplied node to find a node that is a - * DOM representation of a React component rendered by this copy of React. - * - * @param {*} node - * @return {?DOMEventTarget} - * @internal - */ - getFirstReactDOM: function (node) { - return findFirstReactDOMImpl(node); - }, - - /** - * Finds a node with the supplied `targetID` inside of the supplied - * `ancestorNode`. Exploits the ID naming scheme to perform the search - * quickly. - * - * @param {DOMEventTarget} ancestorNode Search from this root. - * @pararm {string} targetID ID of the DOM representation of the component. - * @return {DOMEventTarget} DOM node with the supplied `targetID`. - * @internal - */ - findComponentRoot: function (ancestorNode, targetID) { - var firstChildren = findComponentRootReusableArray; - var childIndex = 0; - - var deepestAncestor = findDeepestCachedAncestor(targetID) || ancestorNode; - - if (process.env.NODE_ENV !== 'production') { - // This will throw on the next line; give an early warning - process.env.NODE_ENV !== 'production' ? warning(deepestAncestor != null, 'React can\'t find the root component node for data-reactid value ' + '`%s`. If you\'re seeing this message, it probably means that ' + 'you\'ve loaded two copies of React on the page. At this time, only ' + 'a single copy of React can be loaded at a time.', targetID) : undefined; - } - - firstChildren[0] = deepestAncestor.firstChild; - firstChildren.length = 1; - - while (childIndex < firstChildren.length) { - var child = firstChildren[childIndex++]; - var targetChild; - - while (child) { - var childID = ReactMount.getID(child); - if (childID) { - // Even if we find the node we're looking for, we finish looping - // through its siblings to ensure they're cached so that we don't have - // to revisit this node again. Otherwise, we make n^2 calls to getID - // when visiting the many children of a single node in order. - - if (targetID === childID) { - targetChild = child; - } else if (ReactInstanceHandles.isAncestorIDOf(childID, targetID)) { - // If we find a child whose ID is an ancestor of the given ID, - // then we can be sure that we only want to search the subtree - // rooted at this child, so we can throw out the rest of the - // search state. - firstChildren.length = childIndex = 0; - firstChildren.push(child.firstChild); - } - } else { - // If this child had no ID, then there's a chance that it was - // injected automatically by the browser, as when a `` - // element sprouts an extra `` child as a side effect of - // `.innerHTML` parsing. Optimistically continue down this - // branch, but not before examining the other siblings. - firstChildren.push(child.firstChild); - } - - child = child.nextSibling; - } - - if (targetChild) { - // Emptying firstChildren/findComponentRootReusableArray is - // not necessary for correctness, but it helps the GC reclaim - // any nodes that were left at the end of the search. - firstChildren.length = 0; - - return targetChild; - } - } - - firstChildren.length = 0; - - true ? process.env.NODE_ENV !== 'production' ? invariant(false, 'findComponentRoot(..., %s): Unable to find element. This probably ' + 'means the DOM was unexpectedly mutated (e.g., by the browser), ' + 'usually due to forgetting a when using tables, nesting tags ' + 'like ,

, or , or using non-SVG elements in an ' + 'parent. ' + 'Try inspecting the child nodes of the element with React ID `%s`.', targetID, ReactMount.getID(ancestorNode)) : invariant(false) : undefined; - }, - - _mountImageIntoNode: function (markup, container, shouldReuseMarkup, transaction) { - !(container && (container.nodeType === ELEMENT_NODE_TYPE || container.nodeType === DOC_NODE_TYPE || container.nodeType === DOCUMENT_FRAGMENT_NODE_TYPE)) ? process.env.NODE_ENV !== 'production' ? invariant(false, 'mountComponentIntoNode(...): Target container is not valid.') : invariant(false) : undefined; - - if (shouldReuseMarkup) { - var rootElement = getReactRootElementInContainer(container); - if (ReactMarkupChecksum.canReuseMarkup(markup, rootElement)) { - return; - } else { - var checksum = rootElement.getAttribute(ReactMarkupChecksum.CHECKSUM_ATTR_NAME); - rootElement.removeAttribute(ReactMarkupChecksum.CHECKSUM_ATTR_NAME); - - var rootMarkup = rootElement.outerHTML; - rootElement.setAttribute(ReactMarkupChecksum.CHECKSUM_ATTR_NAME, checksum); - - var normalizedMarkup = markup; - if (process.env.NODE_ENV !== 'production') { - // because rootMarkup is retrieved from the DOM, various normalizations - // will have occurred which will not be present in `markup`. Here, - // insert markup into a

or