Skip to content

File tif.cpp

File List > lib > tif > tif.cpp

Go to the documentation of this file

#include "tif.hpp"

#include <gdal.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>

#include "assert/gdal_assert.hpp"
#include "gdal_priv.h"
#include "io/crs.hpp"
#include "io/gdal_init.hpp"
#include "isom/colors.hpp"

/// Type trait: true when the candidate type is (unqualified) std::optional<U>
/// for some U, false for every other type.
template <class Candidate>
struct is_std_optional : std::bool_constant<false> {};

/// Partial specialization that matches std::optional<Inner> for any Inner.
template <class Inner>
struct is_std_optional<std::optional<Inner>> : std::bool_constant<true> {};

/// Shorthand for is_std_optional<Candidate>::value.
template <class Candidate>
inline constexpr bool IS_STD_OPTIONAL_V = is_std_optional<Candidate>::value;

/// Always-false constant that depends on T. Used in static_assert so the
/// assertion only fires when the enclosing template branch is instantiated.
template <typename T>
inline constexpr bool DEPENDENT_FALSE_V = false;

/// Returns the GDAL raster data type used to store elements of type T.
///
/// - double -> GDT_Float64, float -> GDT_Float32, unsigned int -> GDT_UInt32,
///   std::byte -> GDT_Byte.
/// - std::optional<U> -> gdal_type<U>(); presence is not encoded in the data
///   type (write_to_tif writes it as a separate alpha band).
/// - Types derived from Color -> GDT_Byte (one byte per channel).
/// Any other T is rejected at compile time with a descriptive message.
template <typename T>
constexpr GDALDataType gdal_type() {
  if constexpr (std::is_same_v<double, T>) {
    return GDT_Float64;
  } else if constexpr (std::is_same_v<float, T>) {
    return GDT_Float32;
  } else if constexpr (std::is_same_v<unsigned int, T>) {
    return GDT_UInt32;
  } else if constexpr (std::is_same_v<std::byte, T>) {
    return GDT_Byte;
  } else if constexpr (IS_STD_OPTIONAL_V<T>) {
    return gdal_type<typename T::value_type>();
  } else if constexpr (std::is_base_of_v<Color, T>) {
    return GDT_Byte;
  } else {
    // Reaching this branch means T has no GDAL mapping; fail with a clear
    // message instead of re-checking a condition already known to be false.
    static_assert(DEPENDENT_FALSE_V<T>, "gdal_type: no GDAL data type mapping for this element type");
  }
}

/// Reads all raster bands of a GDAL-readable file (e.g. a GeoTIFF) into memory.
///
/// Every band must have the same data type as band 1. Pixels are copied
/// verbatim (no type conversion) into one FlexGrid per band, together with
/// the file's geotransform and projection.
///
/// @param filename          Path of the raster to read; must exist.
/// @param progress_tracker  Tracker used for progress reporting.
/// @return The bands plus their georeferencing.
///
/// Reports an error via Assert/Fail if the file does not exist, cannot be
/// opened, has no raster bands, or mixes band data types.
Geo<MultiBand<FlexGrid>> read_tif(const fs::path& filename, ProgressTracker&& progress_tracker) {
  START_TRACKER("reading " + filename.filename().string());
  Assert(fs::exists(filename), "File " + filename.string() + " does not seem to exist");
  ensure_gdal_initialized();

  // Own the dataset handle so GDALClose runs on every exit path, including
  // when one of the checks below exits early by throwing.
  auto close_dataset = [](GDALDataset* ds) { GDALClose(ds); };
  std::unique_ptr<GDALDataset, decltype(close_dataset)> dataset(
      static_cast<GDALDataset*>(GDALOpen(filename.string().c_str(), GA_ReadOnly)), close_dataset);
  if (dataset == nullptr) {
    Fail("Could not open file " + filename.string());
  }

  const size_t width = dataset->GetRasterXSize();
  const size_t height = dataset->GetRasterYSize();
  const int bands = dataset->GetRasterCount();
  // GetRasterBand(1) returns nullptr for a dataset without bands; check first.
  Assert(bands > 0, "File " + filename.string() + " does not contain any raster bands");
  const GDALDataType datatype = dataset->GetRasterBand(1)->GetRasterDataType();
  const unsigned int n_bytes = GDALGetDataTypeSizeBytes(datatype);
  GeoTransform transform(*dataset);
  GeoProjection projection = make_projection_from_wkt(std::string(dataset->GetProjectionRef()));

  Geo<MultiBand<FlexGrid>> result(std::move(transform), std::move(projection), bands, width, height,
                                  n_bytes, datatype);
  for (int band = 0; band < bands; band++) {
    GDALRasterBand* raster_band = dataset->GetRasterBand(band + 1);
    // All bands share the single data type (and element size) taken from band 1.
    AssertEQ(raster_band->GetRasterDataType(), datatype);
    GDALAssert(raster_band->RasterIO(GF_Read, 0, 0, width, height, result[band].data(), width,
                                     height, datatype, 0, 0));
  }
  return result;
}

template <typename GridT>
void write_to_tif(const Geo<GridT>& grid, const fs::path& filename,
                  ProgressTracker&& progress_tracker, const bool include_vertical_crs) {
  START_TRACKER("writing " + filename.filename().string());
  ensure_gdal_initialized();

  int bands;
  GDALDataType datatype;

  if constexpr (std::is_same_v<GridT, MultiBand<FlexGrid>>) {
    bands = grid.size();
    datatype = (GDALDataType)grid[0].data_type();
  } else {
    using T = typename GridT::value_type;
    bands = IS_STD_OPTIONAL_V<T> ? 2 : std::is_base_of_v<Color, T> ? 3 : 1;
    datatype = gdal_type<T>();
  }

  char** options = nullptr;
  options = CSLSetNameValue(options, "COMPRESS", "LZW");
  options = CSLSetNameValue(options, "NUM_THREADS", "8");

  // Only set ALPHA=YES for types that actually have a transparency band.
  // Optional types (2 bands: data + alpha) have alpha; Color types (3 bands:
  // RGB) and scalar types (1 band) do not.
  if constexpr (!std::is_same_v<GridT, MultiBand<FlexGrid>>) {
    using T = typename GridT::value_type;
    if constexpr (IS_STD_OPTIONAL_V<T>) {
      options = CSLSetNameValue(options, "ALPHA", "YES");
    }
  }

  uint64_t estimated_size =
      (uint64_t)grid.width() * grid.height() * bands * GDALGetDataTypeSizeBytes(datatype);
  if (estimated_size > 4000000000ULL) {
    options = CSLSetNameValue(options, "BIGTIFF", "YES");
  } else {
    options = CSLSetNameValue(options, "BIGTIFF", "IF_NEEDED");
  }

  if (grid.width() == 0 || grid.height() == 0) {
    std::cerr << "Warning: skipping TIF write for empty grid: " << filename.string() << "\n";
    CSLDestroy(options);
    return;
  }

  GDALDriver* driver = GetGDALDriverManager()->GetDriverByName("GTiff");
  //
  GDALDataset* dataset = driver->Create(filename.string().c_str(), grid.width(), grid.height(),
                                        bands, datatype, options);

  if (dataset == nullptr) {
    Fail("Could not create file " + filename.string());
  }

  dataset->SetGeoTransform(const_cast<double*>(grid.transform().get_raw()));
  const std::string& projection_wkt =
      include_vertical_crs ? grid.projection().compound_wkt() : grid.projection().to_string();
  dataset->SetProjection(projection_wkt.c_str());

  if constexpr (std::is_same_v<GridT, MultiBand<FlexGrid>>) {
    for (unsigned int band = 0; band < grid.size(); band++) {
      GDALAssert(dataset->GetRasterBand(band + 1)->RasterIO(
          GF_Write, 0, 0, grid.width(), grid.height(), const_cast<std::byte*>(grid[band].data()),
          grid.width(), grid.height(), datatype, 0, 0));
    }
  } else {
    using T = typename GridT::value_type;
    if constexpr (IS_STD_OPTIONAL_V<T>) {
      for (size_t i = 0; i < grid.height(); i++) {
        std::vector<typename T::value_type> data(grid.width());
        std::vector<typename T::value_type> transparent(grid.width());
        for (size_t j = 0; j < grid.width(); j++) {
          data[j] = grid[{j, i}].has_value() ? grid[{j, i}].value() : typename T::value_type(0);
          transparent[j] =
              grid[{j, i}].has_value() ? typename T::value_type(255) : typename T::value_type(0);
        }

        GDALAssert(dataset->GetRasterBand(1)->RasterIO(GF_Write, 0, i, grid.width(), 1, data.data(),
                                                       grid.width(), 1, datatype, 0, 0));
        GDALAssert(dataset->GetRasterBand(2)->RasterIO(
            GF_Write, 0, i, grid.width(), 1, transparent.data(), grid.width(), 1, datatype, 0, 0));
      }
    } else if constexpr (std::is_base_of_v<Color, T>) {
      for (int band = 0; band < 3; band++) {
        std::vector<unsigned char> data(grid.width() * grid.height());
#pragma omp parallel for
        for (size_t i = 0; i < grid.height(); i++) {
          for (size_t j = 0; j < grid.width(); j++) {
            data[i * grid.width() + j] = grid[{j, i}].toRGB()[band];
          }
        }
        GDALAssert(dataset->GetRasterBand(band + 1)->RasterIO(
            GF_Write, 0, 0, grid.width(), grid.height(), data.data(), grid.width(), grid.height(),
            datatype, 0, 0));
      }
    } else {
      GDALAssert(dataset->GetRasterBand(1)->RasterIO(GF_Write, 0, 0, grid.width(), grid.height(),
                                                     const_cast<T*>(&grid[{0, 0}]), grid.width(),
                                                     grid.height(), datatype, 0, 0));
    }
  }

  GDALClose(dataset);
  CSLDestroy(options);
  ;
}

// Explicit instantiations: write_to_tif is a template defined only in this
// .cpp file, so other translation units can only link against the grid
// types instantiated here. Add a line when a new element type is needed.
template void write_to_tif(const GeoGrid<double>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const GeoGrid<float>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const GeoGrid<std::byte>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const GeoGrid<RGBColor>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const GeoGrid<std::optional<std::byte>>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const GeoGrid<std::optional<double>>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const GeoGrid<std::optional<float>>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);
template void write_to_tif(const Geo<MultiBand<FlexGrid>>& grid, const fs::path& filename,
                           ProgressTracker&& progress_tracker, bool include_vertical_crs);

/// Writes a grid as an 8-bit single-band GeoTIFF, linearly rescaling values.
///
/// For bool grids, true becomes 255 and false 0. For any other T, each value v
/// becomes clamp(255 * (v - min) / (max - min), 0, 255), truncated to a byte.
/// If max == min, or the result is NaN (e.g. v is NaN), the pixel becomes 0.
/// Empty grids are skipped with a warning.
///
/// @param grid              Grid to convert and write.
/// @param filename          Output path.
/// @param progress_tracker  Tracker; the TIF write receives
///                          SUBTRACKER(0.5, 1.0, ...) (presumably its second half).
/// @param min_val           Value mapped to 0; defaults to grid.min_value().
/// @param max_val           Value mapped to 255; defaults to grid.max_value().
template <typename T>
void write_to_image_tif(const GeoGrid<T>& grid, const fs::path& filename,
                        ProgressTracker&& progress_tracker, std::optional<T> min_val,
                        std::optional<T> max_val) {
  START_TRACKER("writing image " + filename.filename().string());
  if (grid.width() == 0 || grid.height() == 0) {
    std::cerr << "Warning: skipping TIF write for empty grid: " + filename.string() << "\n";
    return;
  }
  GeoGrid<std::byte> result(grid.width(), grid.height(), GeoTransform(grid.transform()),
                            GeoProjection(grid.projection()));

  // Scan for extrema only when the caller did not provide them. value_or()
  // would evaluate grid.min_value()/max_value() unconditionally.
  const T min = min_val.has_value() ? *min_val : grid.min_value();
  const T max = max_val.has_value() ? *max_val : grid.max_value();

  // Loop-invariant scale. Working in double also keeps unsigned T from
  // wrapping around when a value lies below min (or when max < min).
  [[maybe_unused]] const double offset = static_cast<double>(min);
  [[maybe_unused]] const double range = static_cast<double>(max) - offset;

#pragma omp parallel for
  for (size_t i = 0; i < grid.height(); i++) {
    for (size_t j = 0; j < grid.width(); j++) {
      if constexpr (std::is_same_v<T, bool>) {
        result[{j, i}] = grid[{j, i}] ? std::byte(255) : std::byte(0);
      } else {
        double normalized =
            (range == 0.0) ? 0.0 : 255.0 * (static_cast<double>(grid[{j, i}]) - offset) / range;
        // std::clamp passes NaN through unchanged, and converting NaN to an
        // integer type is undefined behaviour, so map it to 0 explicitly.
        if (std::isnan(normalized)) {
          normalized = 0.0;
        }
        result[{j, i}] =
            static_cast<std::byte>(static_cast<unsigned char>(std::clamp(normalized, 0.0, 255.0)));
      }
    }
  }
  write_to_tif(result, filename, SUBTRACKER(0.5, 1.0, progress_tracker));
}

// Explicit instantiation: write_to_image_tif is a template defined only in
// this .cpp file, so other translation units can only use the element types
// instantiated here.
template void write_to_image_tif(const GeoGrid<double>& grid, const fs::path& filename,
                                 ProgressTracker&& progress_tracker, std::optional<double> min_val,
                                 std::optional<double> max_val);